Fix the mismatch between the forward and backward joiner labels

This commit is contained in:
pkufool 2022-01-28 18:21:15 +08:00
parent 18f997fe51
commit 08729f88b1

View File

@ -18,6 +18,7 @@
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from icefall.utils import add_eos, add_sos from icefall.utils import add_eos, add_sos
@ -162,6 +163,8 @@ class Transducer(nn.Module):
eos_y = add_eos(y, eos_id=blank_id) eos_y = add_eos(y, eos_id=blank_id)
eos_y_padded = eos_y.pad(mode="constant", padding_value=blank_id) eos_y_padded = eos_y.pad(mode="constant", padding_value=blank_id)
eos_y_padded = F.pad(eos_y_padded[:, 1:], pad=(0, 1), value=blank_id)
# backward loss # backward loss
assert self.backward_decoder is not None assert self.backward_decoder is not None
assert self.backward_joiner is not None assert self.backward_joiner is not None
@ -174,7 +177,7 @@ class Transducer(nn.Module):
) )
backward_pruned_loss = k2.rnnt_loss_pruned( backward_pruned_loss = k2.rnnt_loss_pruned(
backward_logits, backward_logits,
sos_y_padded.to(torch.int64), y_padded.to(torch.int64),
ranges, ranges,
blank_id, blank_id,
boundary, boundary,