disable fp16 when computing ctc loss

yaozengwei 2023-02-16 15:27:23 +08:00
parent 20207f0e4e
commit cc74ba574e
2 changed files with 28 additions and 26 deletions

View File

@@ -601,20 +601,21 @@ def compute_loss(
     else:
         raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}")
 
-    dense_fsa_vec = k2.DenseFsaVec(
-        nnet_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
-    )
+    with torch.cuda.amp.autocast(enabled=False):
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output.float(),
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
 
-    ctc_loss = k2.ctc_loss(
-        decoding_graph=decoding_graph,
-        dense_fsa_vec=dense_fsa_vec,
-        output_beam=params.beam_size,
-        delay_penalty=params.delay_penalty if warmup >= 1.0 else 0.0,
-        reduction=params.reduction,
-        use_double_scores=params.use_double_scores,
-    )
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            delay_penalty=params.delay_penalty if warmup >= 1.0 else 0.0,
+            reduction=params.reduction,
+            use_double_scores=params.use_double_scores,
+        )
     ctc_loss_is_finite = torch.isfinite(ctc_loss)
     if not torch.all(ctc_loss_is_finite):
         logging.info("Not all losses are finite!\n" f"ctc_loss: {ctc_loss}")

View File

@@ -710,20 +710,21 @@ def compute_loss(
         # Works with a BPE model
         decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
 
-    dense_fsa_vec = k2.DenseFsaVec(
-        ctc_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
-    )
+    with torch.cuda.amp.autocast(enabled=False):
+        dense_fsa_vec = k2.DenseFsaVec(
+            ctc_output.float(),
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
 
-    ctc_loss = k2.ctc_loss(
-        decoding_graph=decoding_graph,
-        dense_fsa_vec=dense_fsa_vec,
-        output_beam=params.beam_size,
-        delay_penalty=params.ctc_delay_penalty if warmup >= 1.0 else 0.0,
-        reduction="none",
-        use_double_scores=params.use_double_scores,
-    )
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            delay_penalty=params.ctc_delay_penalty if warmup >= 1.0 else 0.0,
+            reduction="none",
+            use_double_scores=params.use_double_scores,
+        )
     ctc_loss_is_finite = torch.isfinite(ctc_loss)
     if not torch.all(ctc_loss_is_finite):
         logging.info("Not all losses are finite!\n" f"ctc_loss: {ctc_loss}")
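
For context, a minimal, self-contained sketch (not part of this commit) of the pattern both hunks apply: run the forward pass under torch.cuda.amp.autocast, then locally disable autocast and cast the activations to float32 before computing a numerically sensitive loss. The toy linear model and cross_entropy loss below are stand-ins for the recipe's encoder output and k2.ctc_loss.

import torch
import torch.nn as nn

# Toy stand-ins for the recipe's encoder and CTC loss (assumptions, not the
# actual icefall code).
model = nn.Linear(80, 500).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(16, 80, device="cuda")
targets = torch.randint(0, 500, (16,), device="cuda")

with torch.cuda.amp.autocast():
    # Forward pass runs in reduced precision (fp16 on CUDA).
    logits = model(x)
    # Disable autocast and cast to float32 for the loss, mirroring how the
    # commit wraps k2.DenseFsaVec / k2.ctc_loss.
    with torch.cuda.amp.autocast(enabled=False):
        loss = nn.functional.cross_entropy(logits.float(), targets)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()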