diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index 823bd8fca..1338c4df3 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -18,6 +18,7 @@ import k2 import torch import torch.nn as nn +import torch.nn.functional as F from encoder_interface import EncoderInterface from icefall.utils import add_eos, add_sos @@ -162,6 +163,8 @@ class Transducer(nn.Module): eos_y = add_eos(y, eos_id=blank_id) eos_y_padded = eos_y.pad(mode="constant", padding_value=blank_id) + eos_y_padded = F.pad(eos_y_padded[:, 1:], pad=(0, 1), value=blank_id) + # backward loss assert self.backward_decoder is not None assert self.backward_joiner is not None @@ -174,7 +177,7 @@ class Transducer(nn.Module): ) backward_pruned_loss = k2.rnnt_loss_pruned( backward_logits, - sos_y_padded.to(torch.int64), + y_padded.to(torch.int64), ranges, blank_id, boundary,