From cc74ba574e341081fa60cfe5f3cdc1fc905a18df Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Thu, 16 Feb 2023 15:27:23 +0800
Subject: [PATCH] disable fp16 when computing ctc loss

---
 egs/librispeech/ASR/conformer_ctc3/train.py    | 27 ++++++++++---------
 .../pruned_transducer_stateless4_ctc/train.py  | 27 ++++++++++---------
 2 files changed, 28 insertions(+), 26 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py
index 2cd223945..e5348e45a 100755
--- a/egs/librispeech/ASR/conformer_ctc3/train.py
+++ b/egs/librispeech/ASR/conformer_ctc3/train.py
@@ -601,20 +601,21 @@ def compute_loss(
     else:
         raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}")
 
-    dense_fsa_vec = k2.DenseFsaVec(
-        nnet_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
-    )
+    with torch.cuda.amp.autocast(enabled=False):
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output.float(),
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            delay_penalty=params.delay_penalty if warmup >= 1.0 else 0.0,
+            reduction=params.reduction,
+            use_double_scores=params.use_double_scores,
+        )
 
-    ctc_loss = k2.ctc_loss(
-        decoding_graph=decoding_graph,
-        dense_fsa_vec=dense_fsa_vec,
-        output_beam=params.beam_size,
-        delay_penalty=params.delay_penalty if warmup >= 1.0 else 0.0,
-        reduction=params.reduction,
-        use_double_scores=params.use_double_scores,
-    )
     ctc_loss_is_finite = torch.isfinite(ctc_loss)
     if not torch.all(ctc_loss_is_finite):
         logging.info("Not all losses are finite!\n" f"ctc_loss: {ctc_loss}")
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4_ctc/train.py
index ada93cc0b..e09b4eba5 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4_ctc/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4_ctc/train.py
@@ -710,20 +710,21 @@ def compute_loss(
     # Works with a BPE model
     decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
 
-    dense_fsa_vec = k2.DenseFsaVec(
-        ctc_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
-    )
+    with torch.cuda.amp.autocast(enabled=False):
+        dense_fsa_vec = k2.DenseFsaVec(
+            ctc_output.float(),
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            delay_penalty=params.ctc_delay_penalty if warmup >= 1.0 else 0.0,
+            reduction="none",
+            use_double_scores=params.use_double_scores,
+        )
 
-    ctc_loss = k2.ctc_loss(
-        decoding_graph=decoding_graph,
-        dense_fsa_vec=dense_fsa_vec,
-        output_beam=params.beam_size,
-        delay_penalty=params.ctc_delay_penalty if warmup >= 1.0 else 0.0,
-        reduction="none",
-        use_double_scores=params.use_double_scores,
-    )
     ctc_loss_is_finite = torch.isfinite(ctc_loss)
    if not torch.all(ctc_loss_is_finite):
         logging.info("Not all losses are finite!\n" f"ctc_loss: {ctc_loss}")
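
Note (not part of the patch): the sketch below illustrates the pattern both hunks apply, i.e. computing the k2 CTC loss in fp32 even when the surrounding training step runs under torch.cuda.amp.autocast. It is a minimal, self-contained example, not the recipes' compute_loss function: the function name compute_ctc_loss_fp32 and its arguments (subsampling_factor, output_beam defaults) are illustrative, while the k2 calls (k2.ctc_graph, k2.DenseFsaVec, k2.ctc_loss) mirror those in the diff; delay_penalty is omitted here for simplicity.

    # Illustrative sketch only: force fp32 for the k2 CTC loss when the caller
    # runs under torch.cuda.amp.autocast(enabled=True).
    import k2
    import torch


    def compute_ctc_loss_fp32(
        nnet_output: torch.Tensor,  # (N, T, C) log-probs, possibly fp16 under autocast
        supervision_segments: torch.Tensor,  # (num_seqs, 3) int32: [seq_idx, start, num_frames]
        token_ids: list,  # one list of token ids per utterance
        subsampling_factor: int = 4,
        output_beam: float = 10.0,
    ) -> torch.Tensor:
        # Build a CTC decoding graph from the reference token ids.
        decoding_graph = k2.ctc_graph(
            token_ids, modified=False, device=nnet_output.device
        )
        # Disable autocast so the lattice scores are accumulated in fp32;
        # fp16 can overflow/underflow here and yield non-finite losses
        # (the recipes check torch.isfinite(ctc_loss) right after this).
        with torch.cuda.amp.autocast(enabled=False):
            dense_fsa_vec = k2.DenseFsaVec(
                nnet_output.float(),
                supervision_segments,
                allow_truncate=subsampling_factor - 1,
            )
            ctc_loss = k2.ctc_loss(
                decoding_graph=decoding_graph,
                dense_fsa_vec=dense_fsa_vec,
                output_beam=output_beam,
                reduction="sum",
                use_double_scores=True,
            )
        return ctc_loss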