mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 22:24:19 +00:00
fix wrong order of token slice
This commit is contained in:
parent
ab08201f6c
commit
46605eaef2
@ -481,9 +481,9 @@ def compute_loss(
|
||||
with torch.set_grad_enabled(is_training):
|
||||
encoder_out = model.encoder(feature)
|
||||
text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
|
||||
loss = decoder_criterion(text_logits, target_tokens.to(device))
|
||||
text_logits = text_logits[:, ignore_prefix_size:, :]
|
||||
target_tokens = target_tokens[:, ignore_prefix_size:]
|
||||
loss = decoder_criterion(text_logits, target_tokens.to(device))
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user