mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 15:14:18 +00:00
fix wrong order of token slice
This commit is contained in:
parent
ab08201f6c
commit
46605eaef2
@ -481,9 +481,9 @@ def compute_loss(
|
|||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
encoder_out = model.encoder(feature)
|
encoder_out = model.encoder(feature)
|
||||||
text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
|
text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
|
||||||
loss = decoder_criterion(text_logits, target_tokens.to(device))
|
|
||||||
text_logits = text_logits[:, ignore_prefix_size:, :]
|
text_logits = text_logits[:, ignore_prefix_size:, :]
|
||||||
target_tokens = target_tokens[:, ignore_prefix_size:]
|
target_tokens = target_tokens[:, ignore_prefix_size:]
|
||||||
|
loss = decoder_criterion(text_logits, target_tokens.to(device))
|
||||||
|
|
||||||
assert loss.requires_grad == is_training
|
assert loss.requires_grad == is_training
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user