mirror of https://github.com/k2-fsa/icefall.git
disable fp16 when computing ctc loss

commit cc74ba574e
parent 20207f0e4e
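Context for the change: icefall trains with PyTorch automatic mixed precision, and the CTC forward-backward computation can overflow or underflow in fp16, producing inf/nan losses. The diff below wraps the loss in torch.cuda.amp.autocast(enabled=False) and casts the network output to fp32. Here is a minimal, self-contained sketch of the same pattern using plain torch.nn.functional.ctc_loss instead of k2 (stand-in model and shapes, not icefall code; assumes a CUDA device):

    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(80, 500).cuda()  # stand-in for the acoustic model
    x = torch.randn(4, 100, 80, device="cuda")  # (N, T, num_features)
    targets = torch.randint(1, 500, (4, 20), device="cuda")  # blank id is 0
    input_lengths = torch.full((4,), 100, dtype=torch.long)
    target_lengths = torch.full((4,), 20, dtype=torch.long)

    with torch.cuda.amp.autocast(enabled=True):
        nnet_output = model(x)  # fp16 under autocast

        # Compute the numerically sensitive CTC loss in fp32: disable
        # autocast for this region and cast the fp16 activations back
        # explicitly, mirroring what the commit does with k2.ctc_loss.
        with torch.cuda.amp.autocast(enabled=False):
            log_probs = nnet_output.float().log_softmax(dim=-1)
            loss = F.ctc_loss(
                log_probs.permute(1, 0, 2),  # ctc_loss expects (T, N, C)
                targets,
                input_lengths,
                target_lengths,
                reduction="sum",
            )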
@@ -601,20 +601,21 @@ def compute_loss(
     else:
         raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}")
 
-    dense_fsa_vec = k2.DenseFsaVec(
-        nnet_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
-    )
+    with torch.cuda.amp.autocast(enabled=False):
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output.float(),
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
 
-    ctc_loss = k2.ctc_loss(
-        decoding_graph=decoding_graph,
-        dense_fsa_vec=dense_fsa_vec,
-        output_beam=params.beam_size,
-        delay_penalty=params.delay_penalty if warmup >= 1.0 else 0.0,
-        reduction=params.reduction,
-        use_double_scores=params.use_double_scores,
-    )
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            delay_penalty=params.delay_penalty if warmup >= 1.0 else 0.0,
+            reduction=params.reduction,
+            use_double_scores=params.use_double_scores,
+        )
 
     ctc_loss_is_finite = torch.isfinite(ctc_loss)
     if not torch.all(ctc_loss_is_finite):
         logging.info("Not all losses are finite!\n" f"ctc_loss: {ctc_loss}")
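One detail worth noting (editorial note, not part of the commit): autocast(enabled=False) only affects operations executed inside the block; it does not change the dtype of tensors created earlier. Since nnet_output was produced under fp16 autocast, the explicit .float() cast is what actually moves the loss computation to fp32. A quick check (assumes a CUDA device):

    import torch

    with torch.cuda.amp.autocast(enabled=True):
        y = torch.nn.Linear(4, 4).cuda()(torch.ones(2, 4, device="cuda"))

    print(y.dtype)  # torch.float16: produced under autocast

    with torch.cuda.amp.autocast(enabled=False):
        print(y.dtype)          # still torch.float16: disabling autocast casts nothing
        print(y.float().dtype)  # torch.float32: the explicit cast is required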
@@ -710,20 +710,21 @@ def compute_loss(
     # Works with a BPE model
     decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
 
-    dense_fsa_vec = k2.DenseFsaVec(
-        ctc_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
-    )
+    with torch.cuda.amp.autocast(enabled=False):
+        dense_fsa_vec = k2.DenseFsaVec(
+            ctc_output.float(),
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
 
-    ctc_loss = k2.ctc_loss(
-        decoding_graph=decoding_graph,
-        dense_fsa_vec=dense_fsa_vec,
-        output_beam=params.beam_size,
-        delay_penalty=params.ctc_delay_penalty if warmup >= 1.0 else 0.0,
-        reduction="none",
-        use_double_scores=params.use_double_scores,
-    )
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            delay_penalty=params.ctc_delay_penalty if warmup >= 1.0 else 0.0,
+            reduction="none",
+            use_double_scores=params.use_double_scores,
+        )
 
     ctc_loss_is_finite = torch.isfinite(ctc_loss)
     if not torch.all(ctc_loss_is_finite):
         logging.info("Not all losses are finite!\n" f"ctc_loss: {ctc_loss}")
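For context on how these regions interact with the rest of training: icefall's training loops enable mixed precision around compute_loss, roughly as sketched below (paraphrased, not the literal icefall code; names such as params.use_fp16 and the exact compute_loss signature are assumptions, check the recipe's train.py). The inner autocast(enabled=False) added by this commit overrides the outer one only for the CTC loss, so the rest of the forward pass still runs in fp16.

    with torch.cuda.amp.autocast(enabled=params.use_fp16):
        loss, loss_info = compute_loss(
            params=params,
            model=model,
            graph_compiler=graph_compiler,
            batch=batch,
            is_training=True,
            warmup=warmup,
        )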