diff --git a/icefall/mmi.py b/icefall/mmi.py index b7777b434..2de8f1aae 100644 --- a/icefall/mmi.py +++ b/icefall/mmi.py @@ -124,6 +124,7 @@ def _compute_mmi_loss_exact_non_optimized( den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True) tot_scores = num_tot_scores - den_scale * den_tot_scores + tot_scores = tot_scores.masked_fill(torch.isinf(tot_scores), 0.0) loss = -1 * tot_scores.sum() return loss