From 20207f0e4e153334f906d764ba779c511a96dfdb Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Wed, 15 Feb 2023 22:22:21 +0800
Subject: [PATCH] filter inf loss in ctc_loss

---
 .../pruned_transducer_stateless4_ctc/train.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4_ctc/train.py
index 6002923fc..ada93cc0b 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4_ctc/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4_ctc/train.py
@@ -721,12 +721,25 @@ def compute_loss(
             dense_fsa_vec=dense_fsa_vec,
             output_beam=params.beam_size,
             delay_penalty=params.ctc_delay_penalty if warmup >= 1.0 else 0.0,
-            reduction="sum",
+            reduction="none",
             use_double_scores=params.use_double_scores,
         )
-        assert ctc_loss.requires_grad == is_training
-        loss += params.ctc_loss_scale * ctc_loss
+        ctc_loss_is_finite = torch.isfinite(ctc_loss)
+        if not torch.all(ctc_loss_is_finite):
+            logging.info("Not all losses are finite!\n" f"ctc_loss: {ctc_loss}")
+            ctc_loss = ctc_loss[ctc_loss_is_finite]
+            # If all entries of ctc_loss are inf or nan,
+            # we stop the training process by raising an exception
+            if torch.all(~ctc_loss_is_finite):
+                raise ValueError(
+                    "There are too many utterances in this batch "
+                    "leading to inf or nan losses."
+                )
+        ctc_loss = ctc_loss.sum()
+        assert ctc_loss.requires_grad == is_training
+
+        loss += params.ctc_loss_scale * ctc_loss

     assert loss.requires_grad == is_training

     info = MetricsTracker()
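
Note (not part of the patch): the pattern introduced here is to compute the CTC loss with reduction="none" so that a per-utterance loss tensor is returned, drop non-finite entries, raise an error only when every entry is inf or nan, and finally sum the surviving entries. The sketch below illustrates that pattern in isolation; it is a minimal, hedged example rather than the committed code. The helper name filter_nonfinite_loss is hypothetical, and a plain tensor stands in for the per-utterance output of k2.ctc_loss.

    import logging

    import torch


    def filter_nonfinite_loss(loss: torch.Tensor) -> torch.Tensor:
        # Mirror the patched train.py logic on a per-utterance loss tensor:
        # mask out inf/nan entries, raise only if *all* entries are non-finite,
        # then reduce the surviving entries with a sum.
        is_finite = torch.isfinite(loss)
        if not torch.all(is_finite):
            logging.info("Not all losses are finite!\n" f"loss: {loss}")
            if torch.all(~is_finite):
                raise ValueError(
                    "There are too many utterances in this batch "
                    "leading to inf or nan losses."
                )
            loss = loss[is_finite]
        return loss.sum()


    # Example: the second utterance produced an inf loss and is dropped,
    # so gradients still flow through the two finite entries.
    ctc_loss = torch.tensor([12.3, float("inf"), 9.8], requires_grad=True)
    total = filter_nonfinite_loss(ctc_loss)  # tensor(22.1000, grad_fn=<SumBackward0>)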