Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)
Filter non-finite losses (#525)

* Filter non-finite losses
* Fixes after review

parent 951b03f6d7
commit 669401869d
@@ -78,6 +78,7 @@ class Transducer(nn.Module):
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
         warmup: float = 1.0,
+        reduction: str = "sum",
     ) -> torch.Tensor:
         """
         Args:
@@ -101,6 +102,10 @@ class Transducer(nn.Module):
           warmup:
             A value warmup >= 0 that determines which modules are active, values
             warmup > 1 "are fully warmed up" and all modules will be active.
+          reduction:
+            "sum" to sum the losses over all utterances in the batch.
+            "none" to return the loss in a 1-D tensor for each utterance
+            in the batch.
         Returns:
           Return the transducer loss.
 
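To make the new docstring entry concrete: a minimal sketch (not part of the commit; the loss values are made up) of what the two reduction modes mean for the returned losses.

    import torch

    # Toy stand-in for per-utterance transducer losses (batch of 3).
    per_utt = torch.tensor([3.1, 2.7, 4.0])

    # reduction="sum": a single scalar for the whole batch.
    loss_sum = per_utt.sum()  # tensor(9.8000)

    # reduction="none": the 1-D per-utterance tensor, which lets the
    # caller inspect or filter individual utterances before reducing.
    loss_none = per_utt  # tensor([3.1000, 2.7000, 4.0000])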
@@ -110,6 +115,7 @@ class Transducer(nn.Module):
             lm_scale * lm_probs + am_scale * am_probs +
             (1-lm_scale-am_scale) * combined_probs
         """
+        assert reduction in ("sum", "none"), reduction
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
         assert y.num_axes == 2, y.num_axes
@@ -155,7 +161,7 @@ class Transducer(nn.Module):
                 lm_only_scale=lm_scale,
                 am_only_scale=am_scale,
                 boundary=boundary,
-                reduction="sum",
+                reduction=reduction,
                 return_grad=True,
             )
 
@@ -188,7 +194,7 @@ class Transducer(nn.Module):
                 ranges=ranges,
                 termination_symbol=blank_id,
                 boundary=boundary,
-                reduction="sum",
+                reduction=reduction,
             )
 
         return (simple_loss, pruned_loss)

@@ -655,7 +655,35 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
+            reduction="none",
         )
+        simple_loss_is_finite = torch.isfinite(simple_loss)
+        pruned_loss_is_finite = torch.isfinite(pruned_loss)
+        is_finite = simple_loss_is_finite & pruned_loss_is_finite
+        if not torch.all(is_finite):
+            logging.info(
+                "Not all losses are finite!\n"
+                f"simple_loss: {simple_loss}\n"
+                f"pruned_loss: {pruned_loss}"
+            )
+            display_and_save_batch(batch, params=params, sp=sp)
+            simple_loss = simple_loss[simple_loss_is_finite]
+            pruned_loss = pruned_loss[pruned_loss_is_finite]
+
+            # If the batch contains at least 10 utterances AND
+            # either all simple_loss or all pruned_loss entries are
+            # inf or nan, we stop training by raising an exception.
+            if feature.size(0) >= 10:
+                if torch.all(~simple_loss_is_finite) or torch.all(
+                    ~pruned_loss_is_finite
+                ):
+                    raise ValueError(
+                        "There are too many utterances in this batch "
+                        "leading to inf or nan losses."
+                    )
+
+        simple_loss = simple_loss.sum()
+        pruned_loss = pruned_loss.sum()
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
         # overwhelming the simple_loss and causing it to diverge,
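The masking above is the heart of the change. A self-contained sketch (dummy loss values, not from the recipe) of the same filtering logic; note that each loss keeps only its own finite entries, so the two sums may run over different subsets of the batch:

    import torch

    # Dummy per-utterance losses with one inf and one nan (batch of 5).
    simple_loss = torch.tensor([1.2, float("inf"), 0.9, 1.1, 1.0])
    pruned_loss = torch.tensor([2.0, 1.8, float("nan"), 2.2, 1.9])

    simple_ok = torch.isfinite(simple_loss)
    pruned_ok = torch.isfinite(pruned_loss)

    if not torch.all(simple_ok & pruned_ok):
        # Drop the offending utterances from each loss independently.
        simple_loss = simple_loss[simple_ok]
        pruned_loss = pruned_loss[pruned_ok]

    print(simple_loss.sum())  # tensor(4.2000) -- inf entry dropped
    print(pruned_loss.sum())  # tensor(7.9000) -- nan entry dropped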
@@ -675,6 +703,10 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
+        # info["frames"] is an approximate number for two reasons:
+        # (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
+        # (2) If some utterances in the batch lead to inf/nan loss, they
+        #     are filtered out.
         info["frames"] = (
             (feature_lens // params.subsampling_factor).sum().item()
         )
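A quick check of the approximation in reason (1), assuming params.subsampling_factor == 4 (the usual value in these recipes; an assumption, not stated in this diff):

    import torch

    feature_lens = torch.tensor([500, 993, 1000])

    # Exact subsampled lengths produced by the model.
    exact = ((feature_lens - 1) // 2 - 1) // 2
    # Approximation used for info["frames"].
    approx = feature_lens // 4

    print(exact)   # tensor([124, 247, 249])
    print(approx)  # tensor([125, 248, 250])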