mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Fixing torch.ctc err (#1485)
* fixing torch.ctc err * Move targets & lengths to CPU
This commit is contained in:
parent
b07d5472c5
commit
b9e6327adf
@ -164,9 +164,9 @@ class AsrModel(nn.Module):
|
||||
|
||||
ctc_loss = torch.nn.functional.ctc_loss(
|
||||
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
|
||||
targets=targets,
|
||||
input_lengths=encoder_out_lens,
|
||||
target_lengths=target_lengths,
|
||||
targets=targets.cpu(),
|
||||
input_lengths=encoder_out_lens.cpu(),
|
||||
target_lengths=target_lengths.cpu(),
|
||||
reduction="sum",
|
||||
)
|
||||
return ctc_loss
|
||||
|
Loading…
x
Reference in New Issue
Block a user