diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index b3e90a052..de9d6d50a 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -10,13 +10,14 @@ There are various folders containing the name `transducer` in this folder. The following table lists the differences among them. | | Encoder | Decoder | Comment | -|---------------------------------------|---------------------|--------------------|---------------------------------------------------| -| `transducer` | Conformer | LSTM | | -| `transducer_stateless` | Conformer | Embedding + Conv1d | | -| `transducer_lstm` | LSTM | LSTM | | -| `transducer_stateless_multi_datasets` | Conformer | Embedding + Conv1d | Using data from GigaSpeech as extra training data | -| `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss | -| `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | +|---------------------------------------|---------------------|--------------------|-------------------------------------------------------| +| `transducer` | Conformer | LSTM | | +| `transducer_stateless` | Conformer | Embedding + Conv1d | Using optimized_transducer from computing RNN-T loss | +| `transducer_stateless2` | Conformer | Embedding + Conv1d | Using torchaudio for computing RNN-T loss | +| `transducer_lstm` | LSTM | LSTM | | +| `transducer_stateless_multi_datasets` | Conformer | Embedding + Conv1d | Using data from GigaSpeech as extra training data | +| `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss | +| `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | The decoder in `transducer_stateless` is modified from the paper diff --git a/egs/librispeech/ASR/transducer_stateless2/__init__.py b/egs/librispeech/ASR/transducer_stateless2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/transducer_stateless2/asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless2/asr_datamodule.py new file mode 120000 index 000000000..fa1b8cca3 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/asr_datamodule.py @@ -0,0 +1 @@ +../tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_stateless2/conformer.py b/egs/librispeech/ASR/transducer_stateless2/conformer.py new file mode 120000 index 000000000..70a7ddf11 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/conformer.py @@ -0,0 +1 @@ +../transducer_stateless/conformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_stateless2/decoder.py b/egs/librispeech/ASR/transducer_stateless2/decoder.py new file mode 120000 index 000000000..eada91097 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/decoder.py @@ -0,0 +1 @@ +../transducer_stateless/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_stateless2/encoder_interface.py b/egs/librispeech/ASR/transducer_stateless2/encoder_interface.py new file mode 120000 index 000000000..aa5d0217a --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/encoder_interface.py @@ -0,0 +1 @@ +../transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_stateless2/joiner.py b/egs/librispeech/ASR/transducer_stateless2/joiner.py new file mode 100644 index 000000000..b0ba7fd83 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/joiner.py @@ -0,0 +1,81 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + + +class Joiner(nn.Module): + def __init__(self, input_dim: int, output_dim: int): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.output_linear = nn.Linear(input_dim, output_dim) + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + encoder_out_len: torch.Tensor, + decoder_out_len: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + Output from the encoder. Its shape is (N, T, self.input_dim). + decoder_out: + Output from the decoder. Its shape is (N, U, self.input_dim). + encoder_out_len: + A 1-D tensor of shape (N,) containing valid number of frames + before padding in `encoder_out`. + decoder_out_len: + A 1-D tensor of shape (N,) containing valid number of frames + before padding in `decoder_out`. + Returns: + Return a tensor of shape (sum_all_TU, self.output_dim). + """ + assert encoder_out.ndim == decoder_out.ndim == 3 + assert encoder_out.size(0) == decoder_out.size(0) + assert encoder_out.size(2) == self.input_dim + assert decoder_out.size(2) == self.input_dim + + N = encoder_out.size(0) + + encoder_out_len = encoder_out_len.tolist() + decoder_out_len = decoder_out_len.tolist() + + encoder_out_list = [ + encoder_out[i, : encoder_out_len[i], :] for i in range(N) + ] + + decoder_out_list = [ + decoder_out[i, : decoder_out_len[i], :] for i in range(N) + ] + + x = [ + e.unsqueeze(1) + d.unsqueeze(0) + for e, d in zip(encoder_out_list, decoder_out_list) + ] + + x = [p.reshape(-1, self.input_dim) for p in x] + x = torch.cat(x) + + activations = torch.tanh(x) + + logits = self.output_linear(activations) + + return logits diff --git a/egs/librispeech/ASR/transducer_stateless2/model.py b/egs/librispeech/ASR/transducer_stateless2/model.py new file mode 100644 index 000000000..8281e1fb5 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/model.py @@ -0,0 +1,143 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface + +from icefall.utils import add_sos + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, C) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, C) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, C). It should contain + one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, C) and (N, U, C). Its + output shape is (N, T, U, C). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + modified_transducer_prob: float = 0.0, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + modified_transducer_prob: + The probability to use modified transducer loss. + Returns: + Return the transducer loss. + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens = self.encoder(x, x_lens) + assert torch.all(x_lens > 0) + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + sos_y_padded = sos_y_padded.to(torch.int64) + + decoder_out = self.decoder(sos_y_padded) + + # +1 here since a blank is prepended to each utterance. + logits = self.joiner( + encoder_out=encoder_out, + decoder_out=decoder_out, + encoder_out_len=x_lens, + decoder_out_len=y_lens + 1, + ) + + # rnnt_loss requires 0 padded targets + # Note: y does not start with SOS + y_padded = y.pad(mode="constant", padding_value=0) + + # We don't put this `import` at the beginning of the file + # as it is required only in the training, not during the + # reference stage + import optimized_transducer + + assert 0 <= modified_transducer_prob <= 1 + + if modified_transducer_prob == 0: + one_sym_per_frame = False + elif random.random() < modified_transducer_prob: + # random.random() returns a float in the range [0, 1) + one_sym_per_frame = True + else: + one_sym_per_frame = False + + loss = optimized_transducer.transducer_loss( + logits=logits, + targets=y_padded, + logit_lengths=x_lens, + target_lengths=y_lens, + blank=blank_id, + reduction="sum", + one_sym_per_frame=one_sym_per_frame, + from_log_softmax=False, + ) + + return loss diff --git a/egs/librispeech/ASR/transducer_stateless2/subsampling.py b/egs/librispeech/ASR/transducer_stateless2/subsampling.py new file mode 120000 index 000000000..af74db6e3 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/subsampling.py @@ -0,0 +1 @@ +../transducer_stateless/subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_stateless2/transformer.py b/egs/librispeech/ASR/transducer_stateless2/transformer.py new file mode 120000 index 000000000..e43f520f9 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/transformer.py @@ -0,0 +1 @@ +../transducer_stateless/transformer.py \ No newline at end of file