diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py
index 8684adeff..891a4177f 100644
--- a/egs/librispeech/ASR/zipformer/optim.py
+++ b/egs/librispeech/ASR/zipformer/optim.py
@@ -300,8 +300,8 @@ class ScaledAdam(BatchedOptimizer):
             # the input is groups of parameter or named parameter.
             for cur_group in iterable_or_groups:
                 assert "named_params" in cur_group
-                name_list = [ x[0] for x in cur_group["named_params"] ]
-                p_list = [ x[1] for x in cur_group["named_params"] ]
+                name_list = [x[0] for x in cur_group["named_params"]]
+                p_list = [x[1] for x in cur_group["named_params"]]
                 del cur_group["named_params"]
                 cur_group["params"] = p_list
                 param_groups.append(cur_group)
@@ -437,7 +437,9 @@ class ScaledAdam(BatchedOptimizer):
                     "ScaledAdam optimizer does not support sparse gradients"
                 )
             if p.numel() == p.shape[0]:  # a batch of scalars
-                tot_sumsq += (grad**2).sum() * (scalar_lr_scale ** 2) # sum() to change shape [1] to []
+                tot_sumsq += (grad**2).sum() * (
+                    scalar_lr_scale**2
+                )  # sum() to change shape [1] to []
             else:
                 tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
 
@@ -448,7 +450,9 @@ class ScaledAdam(BatchedOptimizer):
             )
         first_state["model_norms"][step % clipping_update_period] = tot_norm
 
-        irregular_estimate_steps = [ i for i in [ 10, 20, 40 ] if i < clipping_update_period ]
+        irregular_estimate_steps = [
+            i for i in [10, 20, 40] if i < clipping_update_period
+        ]
         if step % clipping_update_period == 0 or step in irregular_estimate_steps:
             # Print some stats.
             # We don't reach here if step == 0 because we would have returned
@@ -459,9 +463,7 @@ class ScaledAdam(BatchedOptimizer):
             num_norms = sorted_norms.numel()
             quartiles = []
             for n in range(0, 5):
-                index = min(
-                    num_norms - 1, (num_norms // 4) * n
-                )
+                index = min(num_norms - 1, (num_norms // 4) * n)
                 quartiles.append(sorted_norms[index].item())
 
             median = quartiles[2]
@@ -499,18 +501,21 @@ class ScaledAdam(BatchedOptimizer):
             )
             if self.show_dominant_parameters:
                 assert p.shape[0] == len(param_names)
-                self._show_gradient_dominating_parameter(tuples, tot_sumsq,
-                                                         group["scalar_lr_scale"])
+                self._show_gradient_dominating_parameter(
+                    tuples, tot_sumsq, group["scalar_lr_scale"]
+                )
 
         if ans == 0.0:
             for (p, state, param_names) in tuples:
-                p.grad.zero_() # get rid of infinity()
+                p.grad.zero_()  # get rid of infinity()
 
         return ans
 
     def _show_gradient_dominating_parameter(
-        self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor,
-        scalar_lr_scale: float,
+        self,
+        tuples: List[Tuple[Tensor, dict, List[str]]],
+        tot_sumsq: Tensor,
+        scalar_lr_scale: float,
     ):
         """
         Show information of parameter which dominates tot_sumsq.
@@ -531,11 +536,12 @@ class ScaledAdam(BatchedOptimizer):
             batch_grad = p.grad
             if p.numel() == p.shape[0]:  # a batch of scalars
                 # Dummy values used by following `zip` statement.
-                batch_rms_orig = torch.full(p.shape, scalar_lr_scale,
-                                            device=batch_grad.device)
+                batch_rms_orig = torch.full(
+                    p.shape, scalar_lr_scale, device=batch_grad.device
+                )
             else:
                 batch_rms_orig = state["param_rms"]
-            batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2)
+            batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2
             if batch_grad.ndim > 1:
                 # need to guard it with if-statement because sum() sums over
                 # all dims if dim == ().
@@ -679,8 +685,7 @@ class ScaledAdam(BatchedOptimizer):
             # We have to look at the trained model for parameters at or around the
             # param_max_rms, because sometimes they can indicate a problem with the
             # topology or settings.
-            scale_step = torch.minimum(scale_step,
-                                       (param_max_rms - param_rms) / param_rms)
+            scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
 
         delta = state["delta"]
         # the factor of (1-beta1) relates to momentum.
@@ -897,7 +902,8 @@ class Eden(LRScheduler):
         warmup_factor = (
             1.0
             if self.batch >= self.warmup_batches
-            else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
+            else self.warmup_start
+            + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
             # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
         )
 
@@ -946,7 +952,8 @@ class Eden2(LRScheduler):
         warmup_factor = (
             1.0
             if self.batch >= self.warmup_batches
-            else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
+            else self.warmup_start
+            + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
             # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
         )
 