This commit is contained in:
parent ec586c44ce
commit 1e29fe0719
Binary file not shown.
@@ -665,6 +665,7 @@ def compute_loss(
    y = k2.RaggedTensor(y).to(device)

    with torch.set_grad_enabled(is_training):
        '''
        simple_loss, pruned_loss = model(
            x=feature,
            x_lens=feature_lens,
@@ -678,36 +679,18 @@ def compute_loss(
        )
        simple_loss_is_finite = torch.isfinite(simple_loss)
        pruned_loss_is_finite = torch.isfinite(pruned_loss)
        is_finite = simple_loss_is_finite & pruned_loss_is_finite
        if not torch.all(is_finite):
            logging.info(
                "Not all losses are finite!\n"
                f"simple_loss: {simple_loss}\n"
                f"pruned_loss: {pruned_loss}"
            )
            display_and_save_batch(batch, params=params, sp=sp)
            simple_loss = simple_loss[simple_loss_is_finite]
            pruned_loss = pruned_loss[pruned_loss_is_finite]
        '''
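        # A minimal, self-contained sketch of the isfinite-masking idiom used
        # in the quoted-out block above (toy values, not from the commit):
        #   losses = torch.tensor([1.2, float("inf"), 0.7, float("nan")])
        #   finite = torch.isfinite(losses)  # tensor([True, False, True, False])
        #   losses = losses[finite]          # drop the inf/nan entries
        #   total = losses.sum()             # safe to reduce afterwards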
        mse_loss = model(
            x=feature,
            x_lens=feature_lens,
        )

        # If the batch contains more than 10 utterances AND
        # if either all simple_loss or pruned_loss is inf or nan,
        # we stop the training process by raising an exception
        if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite):
            raise ValueError(
                "There are too many utterances in this batch "
                "leading to inf or nan losses."
            )
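        # Sketch of the guard above on a toy batch where every loss is bad
        # (illustrative only, not from the commit):
        #   mask = torch.isfinite(torch.tensor([float("inf"), float("nan")]))
        #   torch.all(~mask)  # tensor(True) -> raise and stop training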

        simple_loss = simple_loss.sum()
        pruned_loss = pruned_loss.sum()
        # after the main warmup step, we keep pruned_loss_scale small
        # for the same amount of time (model_warm_step), to avoid
        # overwhelming the simple_loss and causing it to diverge,
        # in case it had not fully learned the alignment yet.
        pruned_loss_scale = (
            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
        )
        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
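        # The schedule above, tabulated (note that warmup == 1.0 exactly
        # falls through to the final else and gets scale 1.0 as written):
        #   warmup < 1.0        -> pruned_loss_scale = 0.0  (simple loss only)
        #   1.0 < warmup < 2.0  -> pruned_loss_scale = 0.1  (gentle ramp-in)
        #   otherwise           -> pruned_loss_scale = 1.0  (full pruned loss)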
        # loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
        loss = mse_loss

        assert loss.requires_grad == is_training