diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index d0cc9b866..1493da1fb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -852,6 +852,10 @@ def train_one_epoch( cur_grad_scale = scaler._scale.item() if cur_grad_scale < 1.0: scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warn("Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError("grad_scale is too small, exiting: {cur_grad_scale}") if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0]