diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py index 2ef3f1de6..6ee22deec 100644 --- a/egs/librispeech/ASR/transducer_stateless/joiner.py +++ b/egs/librispeech/ASR/transducer_stateless/joiner.py @@ -22,32 +22,50 @@ class Joiner(nn.Module): def __init__(self, input_dim: int, output_dim: int): super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim self.output_linear = nn.Linear(input_dim, output_dim) def forward( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + encoder_out_len: torch.Tensor, + decoder_out_len: torch.Tensor, ) -> torch.Tensor: """ Args: encoder_out: - Output from the encoder. Its shape is (N, T, C). + Output from the encoder. Its shape is (N, T, self.input_dim). decoder_out: - Output from the decoder. Its shape is (N, U, C). + Output from the decoder. Its shape is (N, U, self.input_dim). Returns: - Return a tensor of shape (N, T, U, C). + Return a tensor of shape (sum_all_TU, self.output_dim). """ assert encoder_out.ndim == decoder_out.ndim == 3 assert encoder_out.size(0) == decoder_out.size(0) - assert encoder_out.size(2) == decoder_out.size(2) + assert encoder_out.size(2) == self.input_dim + assert decoder_out.size(2) == self.input_dim - encoder_out = encoder_out.unsqueeze(2) - # Now encoder_out is (N, T, 1, C) + N = encoder_out.size(0) - decoder_out = decoder_out.unsqueeze(1) - # Now decoder_out is (N, 1, U, C) + encoder_out_list = [ + encoder_out[i, : encoder_out_len[i], :] for i in range(N) + ] - logit = encoder_out + decoder_out - logit = torch.tanh(logit) + decoder_out_list = [ + decoder_out[i, : decoder_out_len[i], :] for i in range(N) + ] + + x = [ + e.unsqueeze(1) + d.unsqueeze(0) + for e, d in zip(encoder_out_list, decoder_out_list) + ] + + x = [p.reshape(-1, self.input_dim) for p in x] + x = torch.cat(x) + + logit = torch.tanh(x) output = self.output_linear(logit) diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index 2f0f9a183..98a6f0f37 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -14,15 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Note we use `rnnt_loss` from torchaudio, which exists only in -torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0 -""" import k2 import torch import torch.nn as nn -import torchaudio -import torchaudio.functional from encoder_interface import EncoderInterface from icefall.utils import add_sos @@ -102,18 +96,24 @@ class Transducer(nn.Module): decoder_out = self.decoder(sos_y_padded) - logits = self.joiner(encoder_out, decoder_out) + # +1 here since a blank is prepended to each utterance. + logits = self.joiner( + encoder_out=encoder_out, + decoder_out=decoder_out, + encoder_out_len=x_lens, + decoder_out_len=y_lens + 1, + ) # rnnt_loss requires 0 padded targets # Note: y does not start with SOS y_padded = y.pad(mode="constant", padding_value=0) - assert hasattr(torchaudio.functional, "rnnt_loss"), ( - f"Current torchaudio version: {torchaudio.__version__}\n" - "Please install a version >= 0.10.0" - ) + # We don't put this `import` at the beginning of the file + # as it is required only in the training, not during the + # reference stage + import optimized_transducer - loss = torchaudio.functional.rnnt_loss( + loss = optimized_transducer.transducer_loss( logits=logits, targets=y_padded, logit_lengths=x_lens,