diff --git a/egs/librispeech/ASR/.run_v2.sh.swp b/egs/librispeech/ASR/.run_v2.sh.swp
index c1ec03e7c..17cc27322 100644
Binary files a/egs/librispeech/ASR/.run_v2.sh.swp and b/egs/librispeech/ASR/.run_v2.sh.swp differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp
index 468ac8578..b5713ff2b 100644
Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
index ffc257b56..7d62876a8 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py
@@ -719,41 +719,44 @@ def compute_loss(
 
         loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
-    '''
-    # Compute ctc loss
+    info = MetricsTracker()
+
+    if params.ctc_loss_scale > 0:
+        # Compute ctc loss
 
-    # NOTE: We need `encode_supervisions` to sort sequences with
-    # different duration in decreasing order, required by
-    # `k2.intersect_dense` called in `k2.ctc_loss`
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
-        supervision_segments, token_ids = encode_supervisions(
-            supervisions,
-            subsampling_factor=params.subsampling_factor,
-            token_ids=token_ids,
+        # NOTE: We need `encode_supervisions` to sort sequences with
+        # different duration in decreasing order, required by
+        # `k2.intersect_dense` called in `k2.ctc_loss`
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            supervision_segments, token_ids = encode_supervisions(
+                supervisions,
+                subsampling_factor=params.subsampling_factor,
+                token_ids=token_ids,
+            )
+
+        # Works with a BPE model
+        decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
+        dense_fsa_vec = k2.DenseFsaVec(
+            ctc_output,
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
         )
 
-    # Works with a BPE model
-    decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
-    dense_fsa_vec = k2.DenseFsaVec(
-        ctc_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
-    )
-
-    ctc_loss = k2.ctc_loss(
-        decoding_graph=decoding_graph,
-        dense_fsa_vec=dense_fsa_vec,
-        output_beam=params.beam_size,
-        reduction="sum",
-        use_double_scores=params.use_double_scores,
-    )
-    assert ctc_loss.requires_grad == is_training
-    loss += params.ctc_loss_scale * ctc_loss
-    '''
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            reduction="sum",
+            use_double_scores=params.use_double_scores,
+        )
+        assert ctc_loss.requires_grad == is_training
+        loss += params.ctc_loss_scale * ctc_loss
+
+        info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
     assert loss.requires_grad == is_training
 
-    info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
@@ -762,7 +765,6 @@ def compute_loss(
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
-    #info["ctc_loss"] = ctc_loss.detach().cpu().item()
 
     return loss, info
 
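Note (not part of the patch): the train.py hunks above re-enable the previously commented-out k2 CTC loss behind a `params.ctc_loss_scale > 0` gate, create the MetricsTracker before the gate, and log `info["ctc_loss"]` only when the branch runs. Below is a minimal, self-contained sketch of that gated pattern with toy data; the tensor shapes, token ids, and the 0.2 scale are made-up illustration values, and only the k2 calls (k2.ctc_graph, k2.DenseFsaVec, k2.ctc_loss) mirror what the patch uses.

    # Standalone sketch (assumes torch and k2 are installed); all values are toy data.
    import k2
    import torch

    device = torch.device("cpu")

    # Fake CTC log-probs: 2 utterances, 10 frames, vocab of 5 (0 = blank).
    ctc_output = torch.randn(2, 10, 5).log_softmax(dim=-1).to(device)

    # (sequence_idx, start_frame, num_frames), sorted by duration, decreasing.
    supervision_segments = torch.tensor([[0, 0, 10], [1, 0, 8]], dtype=torch.int32)

    # Toy BPE token ids for the two utterances.
    token_ids = [[1, 2, 3], [4, 2]]

    ctc_loss_scale = 0.2  # stands in for params.ctc_loss_scale

    if ctc_loss_scale > 0:
        # Build a CTC decoding graph from the token ids and wrap the
        # log-probs in a DenseFsaVec, then compute and scale the loss.
        decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
        dense_fsa_vec = k2.DenseFsaVec(ctc_output, supervision_segments)
        ctc_loss = k2.ctc_loss(
            decoding_graph=decoding_graph,
            dense_fsa_vec=dense_fsa_vec,
            output_beam=10,
            reduction="sum",
            use_double_scores=True,
        )
        print("scaled ctc loss:", (ctc_loss_scale * ctc_loss).item())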