From 08729f88b1481374c730521c9acb654ee28bd418 Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 28 Jan 2022 18:21:15 +0800 Subject: [PATCH] Fix the mismatch of forward & backward joiner label --- egs/librispeech/ASR/transducer_stateless/model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index 823bd8fca..1338c4df3 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -18,6 +18,7 @@ import k2 import torch import torch.nn as nn +import torch.nn.functional as F from encoder_interface import EncoderInterface from icefall.utils import add_eos, add_sos @@ -162,6 +163,8 @@ class Transducer(nn.Module): eos_y = add_eos(y, eos_id=blank_id) eos_y_padded = eos_y.pad(mode="constant", padding_value=blank_id) + eos_y_padded = F.pad(eos_y_padded[:, 1:], pad=(0, 1), value=blank_id) + # backward loss assert self.backward_decoder is not None assert self.backward_joiner is not None @@ -174,7 +177,7 @@ class Transducer(nn.Module): ) backward_pruned_loss = k2.rnnt_loss_pruned( backward_logits, - sos_y_padded.to(torch.int64), + y_padded.to(torch.int64), ranges, blank_id, boundary,