diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py index 8d5930437..edea7e7ef 100755 --- a/egs/aishell/ASR/whisper/train.py +++ b/egs/aishell/ASR/whisper/train.py @@ -481,9 +481,9 @@ def compute_loss( with torch.set_grad_enabled(is_training): encoder_out = model.encoder(feature) text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out) - loss = decoder_criterion(text_logits, target_tokens.to(device)) text_logits = text_logits[:, ignore_prefix_size:, :] target_tokens = target_tokens[:, ignore_prefix_size:] + loss = decoder_criterion(text_logits, target_tokens.to(device)) assert loss.requires_grad == is_training