diff --git a/egs/librispeech/ASR/incremental_transf/identity_train.py b/egs/librispeech/ASR/incremental_transf/identity_train.py
index 66e59465c..45a5d8e08 100755
--- a/egs/librispeech/ASR/incremental_transf/identity_train.py
+++ b/egs/librispeech/ASR/incremental_transf/identity_train.py
@@ -665,6 +665,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
+        '''
         simple_loss, pruned_loss = model(
             x=feature,
             x_lens=feature_lens,
@@ -678,36 +679,18 @@
         )
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
-        is_finite = simple_loss_is_finite & pruned_loss_is_finite
-        if not torch.all(is_finite):
-            logging.info(
-                "Not all losses are finite!\n"
-                f"simple_loss: {simple_loss}\n"
-                f"pruned_loss: {pruned_loss}"
-            )
-            display_and_save_batch(batch, params=params, sp=sp)
-            simple_loss = simple_loss[simple_loss_is_finite]
-            pruned_loss = pruned_loss[pruned_loss_is_finite]
+        '''
+        mse_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+        )
 
-        # If the batch contains more than 10 utterances AND
-        # if either all simple_loss or pruned_loss is inf or nan,
-        # we stop the training process by raising an exception
-        if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite):
-            raise ValueError(
-                "There are too many utterances in this batch "
-                "leading to inf or nan losses."
-            )
-
-        simple_loss = simple_loss.sum()
-        pruned_loss = pruned_loss.sum()
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
-        pruned_loss_scale = (
-            0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+        # loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+        loss = mse_loss
 
         assert loss.requires_grad == is_training
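
Note on the new compute_loss path: it assumes the identity model's forward returns a single scalar loss when called as model(x=feature, x_lens=feature_lens); the model definition itself is not part of this diff. The sketch below is a hypothetical reconstruction of that interface, assuming a shape-preserving encoder and an MSE objective between the encoder output and the input features. The names IdentityModel and encoder are illustrative, not taken from the diff.

    import torch
    import torch.nn as nn


    class IdentityModel(nn.Module):
        """Hypothetical sketch of the interface compute_loss now expects.

        Assumes the wrapped encoder maps (N, T, C) -> (N, T, C), so the
        input features can serve as the reconstruction target.
        """

        def __init__(self, encoder: nn.Module):
            super().__init__()
            self.encoder = encoder

        def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
            # x: (N, T, C) padded features; x_lens: (N,) valid frame counts.
            y = self.encoder(x)
            # Mask padded frames so they do not contribute to the loss.
            mask = (
                torch.arange(x.size(1), device=x.device)[None, :] < x_lens[:, None]
            ).unsqueeze(-1)
            # Scalar MSE over valid positions; the result is differentiable,
            # so `assert loss.requires_grad == is_training` still holds when
            # gradients are enabled.
            return ((y - x) ** 2 * mask).sum() / (mask.sum() * x.size(2)).clamp(min=1)

The masking matters here because feature tensors in a batch are padded to the longest utterance; without it, padding frames would bias the reconstruction loss toward zero.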