Fix wrong operation order: slice off the prefix tokens from text_logits and target_tokens before computing the decoder loss

This commit is contained in:
Yuekai Zhang 2024-01-22 16:24:46 +08:00
parent ab08201f6c
commit 46605eaef2

View File

@ -481,9 +481,9 @@ def compute_loss(
with torch.set_grad_enabled(is_training):
encoder_out = model.encoder(feature)
text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
loss = decoder_criterion(text_logits, target_tokens.to(device))
text_logits = text_logits[:, ignore_prefix_size:, :]
target_tokens = target_tokens[:, ignore_prefix_size:]
loss = decoder_criterion(text_logits, target_tokens.to(device))
assert loss.requires_grad == is_training