Move targets & lengths to CPU

This commit is contained in:
Teo 2024-02-02 23:28:23 +09:00
parent 6a6cd82b7a
commit dc238aa4b5

View File

@ -164,9 +164,9 @@ class AsrModel(nn.Module):
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
targets=targets.to(torch.long),
input_lengths=encoder_out_lens.to(torch.long),
target_lengths=target_lengths.to(torch.long),
targets=targets.cpu(),
input_lengths=encoder_out_lens.cpu(),
target_lengths=target_lengths.cpu(),
reduction="sum",
)
return ctc_loss