diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 2434fd41d..452102d21 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -78,6 +78,7 @@ class Transducer(nn.Module): am_scale: float = 0.0, lm_scale: float = 0.0, warmup: float = 1.0, + reduction: str = "sum", ) -> torch.Tensor: """ Args: @@ -101,6 +102,10 @@ class Transducer(nn.Module): warmup: A value warmup >= 0 that determines which modules are active, values warmup > 1 "are fully warmed up" and all modules will be active. + reduction: + "sum" to sum the losses over all utterances in the batch. + "none" to return the loss in a 1-D tensor for each utterance + in the batch. Returns: Return the transducer loss. @@ -110,6 +115,7 @@ class Transducer(nn.Module): lm_scale * lm_probs + am_scale * am_probs + (1-lm_scale-am_scale) * combined_probs """ + assert reduction in ("sum", "none"), reduction assert x.ndim == 3, x.shape assert x_lens.ndim == 1, x_lens.shape assert y.num_axes == 2, y.num_axes @@ -155,7 +161,7 @@ class Transducer(nn.Module): lm_only_scale=lm_scale, am_only_scale=am_scale, boundary=boundary, - reduction="sum", + reduction=reduction, return_grad=True, ) @@ -188,7 +194,7 @@ class Transducer(nn.Module): ranges=ranges, termination_symbol=blank_id, boundary=boundary, - reduction="sum", + reduction=reduction, ) return (simple_loss, pruned_loss) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 3bfe22155..b7ef288c6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -655,7 +655,35 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, warmup=warmup, + reduction="none", ) + simple_loss_is_finite = torch.isfinite(simple_loss) + pruned_loss_is_finite = torch.isfinite(pruned_loss) + is_finite = simple_loss_is_finite & pruned_loss_is_finite + if not torch.all(is_finite): + logging.info( + "Not all losses are finite!\n" + f"simple_loss: {simple_loss}\n" + f"pruned_loss: {pruned_loss}" + ) + display_and_save_batch(batch, params=params, sp=sp) + simple_loss = simple_loss[simple_loss_is_finite] + pruned_loss = pruned_loss[pruned_loss_is_finite] + + # If the batch contains more than 10 utterance AND + # if either all simple_loss or pruned_loss is inf or nan, + # we stop the training process by raising an exception + if feature.size(0) >= 10: + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): + raise ValueError( + "There are too many utterances in this batch " + "leading to inf or nan losses." + ) + + simple_loss = simple_loss.sum() + pruned_loss = pruned_loss.sum() # after the main warmup step, we keep pruned_loss_scale small # for the same amount of time (model_warm_step), to avoid # overwhelming the simple_loss and causing it to diverge, @@ -675,6 +703,10 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") + # info["frames"] is an approximate number for two reasons: + # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 + # (2) If some utterances in the batch lead to inf/nan loss, they + # are filtered out. info["frames"] = ( (feature_lens // params.subsampling_factor).sum().item() )