diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py
index c9b76526c..8ee2b0eb4 100644
--- a/egs/librispeech/ASR/zipformer/optim.py
+++ b/egs/librispeech/ASR/zipformer/optim.py
@@ -491,6 +491,12 @@ class ScaledAdam(BatchedOptimizer):
                 if self.show_dominant_parameters:
                     assert p.shape[0] == len(param_names)
                     self._show_gradient_dominating_parameter(tuples, tot_sumsq)
+            if ans != ans:  # e.g. ans is nan
+                ans = 0.0
+            if ans == 0.0:
+                for p, state, param_names in tuples:
+                    p.grad.zero_()  # get rid of infinity()
+
             return ans

     def _show_gradient_dominating_parameter(
@@ -573,7 +579,7 @@ class ScaledAdam(BatchedOptimizer):
             grad = p.grad
             if clipping_scale != 1.0:
-                grad = grad * clipping_scale
+                grad *= clipping_scale
             step = state["step"]
             delta = state["delta"]
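
Note on the added guard: `ans != ans` works because NaN is the only float value that compares unequal to itself. Below is a minimal standalone sketch of the same idea, assuming a hypothetical `safe_clipping_scale` helper over a plain list of parameters (neither the helper name nor its signature is part of the patch):

import torch

def safe_clipping_scale(ans: float, params: list) -> float:
    # NaN is the only value for which x != x holds, so this catches
    # a NaN clipping scale without calling math.isnan().
    if ans != ans:
        ans = 0.0
    if ans == 0.0:
        # A zero scale means the gradient norm was inf or NaN; zeroing
        # the gradients turns this optimizer step into a no-op instead
        # of propagating inf/NaN into the weights.
        for p in params:
            if p.grad is not None:
                p.grad.zero_()
    return ans

# Usage: a parameter with an inf gradient and a NaN scale is neutralized.
p = torch.nn.Parameter(torch.ones(3))
p.grad = torch.full((3,), float("inf"))
print(safe_clipping_scale(float("nan"), [p]))  # 0.0; p.grad is now all zeros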