From 46605eaef2f82df889b2000287142cc835190f48 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Mon, 22 Jan 2024 16:24:46 +0800 Subject: [PATCH] fix wrong order of token slice --- egs/aishell/ASR/whisper/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py index 8d5930437..edea7e7ef 100755 --- a/egs/aishell/ASR/whisper/train.py +++ b/egs/aishell/ASR/whisper/train.py @@ -481,9 +481,9 @@ def compute_loss( with torch.set_grad_enabled(is_training): encoder_out = model.encoder(feature) text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out) - loss = decoder_criterion(text_logits, target_tokens.to(device)) text_logits = text_logits[:, ignore_prefix_size:, :] target_tokens = target_tokens[:, ignore_prefix_size:] + loss = decoder_criterion(text_logits, target_tokens.to(device)) assert loss.requires_grad == is_training