Fixing torch.ctc err (#1485)

* fixing torch.ctc err

* Move targets & lengths to CPU
This commit is contained in:
Teo Wen Shen 2024-02-03 07:25:27 +09:00 committed by GitHub
parent b07d5472c5
commit b9e6327adf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -164,9 +164,9 @@ class AsrModel(nn.Module):
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
targets=targets,
input_lengths=encoder_out_lens,
target_lengths=target_lengths,
targets=targets.cpu(),
input_lengths=encoder_out_lens.cpu(),
target_lengths=target_lengths.cpu(),
reduction="sum",
)
return ctc_loss