mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-02 21:54:18 +00:00
Move targets & lengths to CPU
This commit is contained in:
parent
6a6cd82b7a
commit
dc238aa4b5
@ -164,9 +164,9 @@ class AsrModel(nn.Module):
|
||||
|
||||
ctc_loss = torch.nn.functional.ctc_loss(
|
||||
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
|
||||
targets=targets.to(torch.long),
|
||||
input_lengths=encoder_out_lens.to(torch.long),
|
||||
target_lengths=target_lengths.to(torch.long),
|
||||
targets=targets.cpu(),
|
||||
input_lengths=encoder_out_lens.cpu(),
|
||||
target_lengths=target_lengths.cpu(),
|
||||
reduction="sum",
|
||||
)
|
||||
return ctc_loss
|
||||
|
Loading…
x
Reference in New Issue
Block a user