From 9e5a5d7839aa3052e46dcf25b239a37449f8cd5e Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Thu, 2 Nov 2023 16:10:08 +0800
Subject: [PATCH] Incorporate some latest changes to `optim.py` (#1359)

* init commit

* black formatted

* isort formatted
---
 egs/librispeech/ASR/zipformer/optim.py | 171 +++++++++++++++++--------
 1 file changed, 121 insertions(+), 50 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py
index 8ee2b0eb4..a663db708 100644
--- a/egs/librispeech/ASR/zipformer/optim.py
+++ b/egs/librispeech/ASR/zipformer/optim.py
@@ -22,7 +22,7 @@ from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from lhotse.utils import fix_random_seed
-from torch import Tensor
+from torch import Tensor, nn
 from torch.optim import Optimizer
 
 
@@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer):
 
         yield tuples  # <-- calling code will do the actual optimization here!
 
-        for (stacked_params, _state, _names), batch in zip(tuples, batches):
+        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])
 
@@ -181,6 +181,7 @@ class ScaledAdam(BatchedOptimizer):
         size_update_period=4,
         clipping_update_period=100,
     ):
+
         defaults = dict(
             lr=lr,
             clipping_scale=clipping_scale,
@@ -326,7 +327,9 @@ class ScaledAdam(BatchedOptimizer):
         batch = True
 
         for group, group_params_names in zip(self.param_groups, self.parameters_names):
+
             with self.batched_params(group["params"], group_params_names) as batches:
+
                 # batches is list of pairs (stacked_param, state). stacked_param is like
                 # a regular parameter, and will have a .grad, but the 1st dim corresponds to
                 # a stacking dim, it is not a real dim.
@@ -423,16 +426,19 @@ class ScaledAdam(BatchedOptimizer):
             # parameters' state won't have been initialized yet.
             return 1.0
         clipping_update_period = group["clipping_update_period"]
+        scalar_lr_scale = group["scalar_lr_scale"]
 
         tot_sumsq = torch.tensor(0.0, device=first_p.device)
-        for p, state, param_names in tuples:
+        for (p, state, param_names) in tuples:
             grad = p.grad
             if grad.is_sparse:
                 raise RuntimeError(
                     "ScaledAdam optimizer does not support sparse gradients"
                 )
             if p.numel() == p.shape[0]:  # a batch of scalars
-                tot_sumsq += (grad**2).sum()  # sum() to change shape [1] to []
+                tot_sumsq += (grad**2).sum() * (
+                    scalar_lr_scale**2
+                )  # sum() to change shape [1] to []
             else:
                 tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
 
@@ -443,64 +449,72 @@ class ScaledAdam(BatchedOptimizer):
             )
         first_state["model_norms"][step % clipping_update_period] = tot_norm
 
-        if step % clipping_update_period == 0:
+        irregular_estimate_steps = [
+            i for i in [10, 20, 40] if i < clipping_update_period
+        ]
+        if step % clipping_update_period == 0 or step in irregular_estimate_steps:
             # Print some stats.
             # We don't reach here if step == 0 because we would have returned
             # above.
             sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
+            if step in irregular_estimate_steps:
+                sorted_norms = sorted_norms[-step:]
+            num_norms = sorted_norms.numel()
             quartiles = []
             for n in range(0, 5):
-                index = min(
-                    clipping_update_period - 1, (clipping_update_period // 4) * n
-                )
+                index = min(num_norms - 1, (num_norms // 4) * n)
                 quartiles.append(sorted_norms[index].item())
 
             median = quartiles[2]
             threshold = clipping_scale * median
+            if step in irregular_estimate_steps:
+                # use larger thresholds on first few steps of estimating threshold,
+                # as norm may be changing rapidly.
+                threshold = threshold * 2.0
             first_state["model_norm_threshold"] = threshold
             percent_clipped = (
-                first_state["num_clipped"] * 100.0 / clipping_update_period
+                first_state["num_clipped"] * 100.0 / num_norms
                 if "num_clipped" in first_state
                 else 0.0
             )
             first_state["num_clipped"] = 0
             quartiles = " ".join(["%.3e" % x for x in quartiles])
-            logging.info(
+            logging.warn(
                 f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
                 f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
             )
 
-        if step < clipping_update_period:
-            return 1.0  # We have not yet estimated a norm to clip to.
-        else:
-            try:
-                model_norm_threshold = first_state["model_norm_threshold"]
-            except KeyError:
-                logging.info(
-                    "Warning: model_norm_threshold not in state: possibly "
-                    "you changed config when restarting, adding clipping_scale option?"
-                )
-                return 1.0
-            ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
-            if ans < 1.0:
-                first_state["num_clipped"] += 1
-            if ans < 0.1:
-                logging.warn(
-                    f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
-                )
-                if self.show_dominant_parameters:
-                    assert p.shape[0] == len(param_names)
-                    self._show_gradient_dominating_parameter(tuples, tot_sumsq)
-            if ans != ans:  # e.g. ans is nan
-                ans = 0.0
-            if ans == 0.0:
-                for p, state, param_names in tuples:
-                    p.grad.zero_()  # get rid of infinity()
+        try:
+            model_norm_threshold = first_state["model_norm_threshold"]
+        except KeyError:
+            return 1.0  # threshold has not yet been set.
 
-            return ans
+        ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
+        if ans != ans:  # e.g. ans is nan
+            ans = 0.0
+        if ans < 1.0:
+            first_state["num_clipped"] += 1
+        if ans < 0.1:
+            logging.warn(
+                f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
+            )
+            if self.show_dominant_parameters:
+                assert p.shape[0] == len(param_names)
+                self._show_gradient_dominating_parameter(
+                    tuples, tot_sumsq, group["scalar_lr_scale"]
+                )
+
+        if ans == 0.0:
+            for (p, state, param_names) in tuples:
+                p.grad.zero_()  # get rid of infinity()
+
+        return ans
 
     def _show_gradient_dominating_parameter(
-        self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
+        self,
+        tuples: List[Tuple[Tensor, dict, List[str]]],
+        tot_sumsq: Tensor,
+        scalar_lr_scale: float,
     ):
         """
         Show information of parameter which dominates tot_sumsq.
@@ -516,29 +530,30 @@ class ScaledAdam(BatchedOptimizer):
         from tuples, we still pass it to save some time.
         """
         all_sumsq_orig = {}
-        for p, state, batch_param_names in tuples:
+        for (p, state, batch_param_names) in tuples:
             # p is a stacked batch parameters.
             batch_grad = p.grad
             if p.numel() == p.shape[0]:  # a batch of scalars
-                batch_sumsq_orig = batch_grad**2
                 # Dummy values used by following `zip` statement.
-                batch_rms_orig = torch.ones(p.shape[0])
+                batch_rms_orig = torch.full(
+                    p.shape, scalar_lr_scale, device=batch_grad.device
+                )
             else:
                 batch_rms_orig = state["param_rms"]
-                batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
+            batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2
+            if batch_grad.ndim > 1:
+                # need to guard it with if-statement because sum() sums over
+                # all dims if dim == ().
+                batch_sumsq_orig = batch_sumsq_orig.sum(
                     dim=list(range(1, batch_grad.ndim))
                 )
-
             for name, sumsq_orig, rms, grad in zip(
                 batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
             ):
+
                 proportion_orig = sumsq_orig / tot_sumsq
                 all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
 
-        assert torch.isclose(
-            sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
-            torch.tensor(1.0),
-        )
         sorted_by_proportion = {
             k: v
             for k, v in sorted(
@@ -552,7 +567,7 @@ class ScaledAdam(BatchedOptimizer):
             dominant_rms,
             dominant_grad,
         ) = sorted_by_proportion[dominant_param_name]
-        logging.info(
+        logging.warn(
            f"Parameter dominating tot_sumsq {dominant_param_name}"
            f" with proportion {dominant_proportion:.2f},"
            f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
@@ -826,7 +841,7 @@ class LRScheduler(object):
     def print_lr(self, is_verbose, group, lr):
         """Display the current learning rate."""
         if is_verbose:
-            logging.info(
+            logging.warn(
                 f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
                 f" of group {group} to {lr:.4e}."
             )
@@ -841,8 +856,14 @@ class Eden(LRScheduler):
     where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches
     and then stays constant at 1.
 
+    If you don't have the concept of epochs, or one epoch takes a very long time,
+    you can replace the notion of 'epoch' with some measure of the amount of data
+    processed, e.g. hours of data or frames of data, with 'lr_epochs' being set to
+    some measure representing "quite a lot of data": say, one fifth or one third
+    of an entire training run, but it doesn't matter much.  You could also use
+    Eden2 which has only the notion of batches.
 
-    E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
+    We suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
 
     Args:
         optimizer: the optimizer to change the learning rates on
@@ -888,6 +909,56 @@ class Eden(LRScheduler):
         return [x * factor * warmup_factor for x in self.base_lrs]
 
 
+class Eden2(LRScheduler):
+    """
+    Eden2 scheduler, simpler than Eden because it does not use the notion of epoch,
+    only batches.
+
+    The basic formula (before warmup) is:
+      lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5) * warmup
+
+    where `warmup` increases linearly from 0.5 to 1 over `warmup_batches` batches
+    and then stays constant at 1.
+
+
+    We suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
+
+    Args:
+        optimizer: the optimizer to change the learning rates on
+        lr_batches: the number of batches after which we start significantly
+              decreasing the learning rate, suggest 5000.
+    """
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        lr_batches: Union[int, float],
+        warmup_batches: Union[int, float] = 500.0,
+        warmup_start: float = 0.5,
+        verbose: bool = False,
+    ):
+        super().__init__(optimizer, verbose)
+        self.lr_batches = lr_batches
+        self.warmup_batches = warmup_batches
+
+        assert 0.0 <= warmup_start <= 1.0, warmup_start
+        self.warmup_start = warmup_start
+
+    def get_lr(self):
+        factor = (
+            (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
+        ) ** -0.5
+        warmup_factor = (
+            1.0
+            if self.batch >= self.warmup_batches
+            else self.warmup_start
+            + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
+            # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
+        )
+
+        return [x * factor * warmup_factor for x in self.base_lrs]
+
+
 def _test_eden():
     m = torch.nn.Linear(100, 100)
     optim = ScaledAdam(m.parameters(), lr=0.03)
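
Usage note (not part of the patch): the Python sketch below shows one way the new Eden2 scheduler could be driven alongside ScaledAdam, mirroring the _test_eden() helper visible above. The ScaledAdam and Eden2 constructor arguments are taken from this file; the scheduler.step_batch() and get_last_lr() calls, the toy model, and the loss are assumptions about the LRScheduler base class and the surrounding training loop, following how icefall training scripts typically use these schedulers.

    import torch

    # Assumption: egs/librispeech/ASR/zipformer/optim.py is importable as `optim`.
    from optim import Eden2, ScaledAdam

    m = torch.nn.Linear(100, 100)

    # Mirrors _test_eden(): plain parameters and a base lr are enough here.
    # base_lr = 0.04 is the value suggested in the Eden/Eden2 docstrings;
    # clipping_scale is one of the constructor arguments shown in the diff.
    optimizer = ScaledAdam(m.parameters(), lr=0.04, clipping_scale=2.0)

    # Eden2 has no notion of epochs, only batches:
    #   lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5) * warmup
    scheduler = Eden2(optimizer, lr_batches=5000, warmup_batches=500.0)

    for batch_idx in range(100):
        x = torch.randn(32, 100)
        loss = (m(x) ** 2).mean()  # toy objective, only to produce gradients
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step_batch()  # assumed: LRScheduler advances its batch count here

    print(scheduler.get_last_lr())  # assumed helper on the LRScheduler base class

Because Eden2 drops the epoch term of Eden, only the per-batch step matters for the decay, which is why the docstring suggests it when an "epoch" is not a meaningful unit for the training run.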