from local

dohe0342 2023-01-09 20:26:14 +09:00
parent ec586c44ce
commit 1e29fe0719
2 changed files with 8 additions and 25 deletions


@@ -665,6 +665,7 @@ def compute_loss(
    y = k2.RaggedTensor(y).to(device)
    with torch.set_grad_enabled(is_training):
        '''
        simple_loss, pruned_loss = model(
            x=feature,
            x_lens=feature_lens,
@@ -678,36 +679,18 @@ def compute_loss(
        )
        simple_loss_is_finite = torch.isfinite(simple_loss)
        pruned_loss_is_finite = torch.isfinite(pruned_loss)
        is_finite = simple_loss_is_finite & pruned_loss_is_finite
        if not torch.all(is_finite):
            logging.info(
                "Not all losses are finite!\n"
                f"simple_loss: {simple_loss}\n"
                f"pruned_loss: {pruned_loss}"
            )
            display_and_save_batch(batch, params=params, sp=sp)
            simple_loss = simple_loss[simple_loss_is_finite]
            pruned_loss = pruned_loss[pruned_loss_is_finite]
        '''
        mse_loss = model(
            x=feature,
            x_lens=feature_lens,
        )
            # If the batch contains more than 10 utterances AND
            # if either all simple_loss or pruned_loss is inf or nan,
            # we stop the training process by raising an exception
            if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite):
                raise ValueError(
                    "There are too many utterances in this batch "
                    "leading to inf or nan losses."
                )
        simple_loss = simple_loss.sum()
        pruned_loss = pruned_loss.sum()
        # after the main warmup step, we keep pruned_loss_scale small
        # for the same amount of time (model_warm_step), to avoid
        # overwhelming the simple_loss and causing it to diverge,
        # in case it had not fully learned the alignment yet.
        pruned_loss_scale = (
            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
        )
        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
        #loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
        loss = mse_loss

    assert loss.requires_grad == is_training
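
Taken together, the change bypasses the pruned-transducer losses and drives training with a single MSE objective returned by the model. Below is a minimal sketch of the resulting compute path; the assumption that model(x=feature, x_lens=feature_lens) returns a scalar MSE loss, and the helper name compute_mse_loss_sketch, are read off this diff and are not confirmed elsewhere in the repository.

import torch

def compute_mse_loss_sketch(model, feature, feature_lens, is_training):
    # Hypothetical condensation of the new path in compute_loss: track
    # gradients only during training and take the model's output directly
    # as the loss, since the transducer losses are commented out.
    with torch.set_grad_enabled(is_training):
        mse_loss = model(x=feature, x_lens=feature_lens)  # assumed to return a scalar loss
        loss = mse_loss
    assert loss.requires_grad == is_training
    return loss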