diff --git a/egs/librispeech/ASR/transducer/decoder.py b/egs/librispeech/ASR/transducer/decoder.py
new file mode 100644
index 000000000..995afc818
--- /dev/null
+++ b/egs/librispeech/ASR/transducer/decoder.py
@@ -0,0 +1,98 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+
+# TODO(fangjun): Support switching between LSTM and GRU
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        vocab_size: int,
+        embedding_dim: int,
+        blank_id: int,
+        sos_id: int,
+        num_layers: int,
+        hidden_dim: int,
+        embedding_dropout: float = 0.0,
+        rnn_dropout: float = 0.0,
+    ):
+        """
+        Args:
+          vocab_size:
+            Number of tokens of the modeling unit, including blank.
+          embedding_dim:
+            Dimension of the input embedding.
+          blank_id:
+            The ID of the blank symbol.
+          sos_id:
+            The ID of the SOS symbol.
+          num_layers:
+            Number of LSTM layers.
+          hidden_dim:
+            Hidden dimension of LSTM layers.
+          embedding_dropout:
+            Dropout rate for the embedding layer.
+          rnn_dropout:
+            Dropout rate for LSTM layers.
+        """
+        super().__init__()
+        self.embedding = nn.Embedding(
+            num_embeddings=vocab_size,
+            embedding_dim=embedding_dim,
+            padding_idx=blank_id,
+        )
+        self.embedding_dropout = nn.Dropout(embedding_dropout)
+        # TODO(fangjun): Use layer normalized LSTM
+        self.rnn = nn.LSTM(
+            input_size=embedding_dim,
+            hidden_size=hidden_dim,
+            num_layers=num_layers,
+            batch_first=True,
+            dropout=rnn_dropout,
+        )
+        self.blank_id = blank_id
+        self.sos_id = sos_id
+        self.output_linear = nn.Linear(hidden_dim, hidden_dim)
+
+    def forward(
+        self,
+        y: torch.Tensor,
+        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Args:
+          y:
+            A 2-D tensor of shape (N, U) with SOS prepended.
+          states:
+            A tuple of two tensors containing the state information of
+            the LSTM layers in this decoder.
+        Returns:
+          Return a tuple containing:
+
+            - rnn_output, a tensor of shape (N, U, C)
+            - (h, c), containing the state information for the LSTM layers.
+              Both are of shape (num_layers, N, C)
+        """
+        embedding_out = self.embedding(y)
+        embedding_out = self.embedding_dropout(embedding_out)
+        rnn_out, (h, c) = self.rnn(embedding_out, states)
+        out = self.output_linear(rnn_out)
+
+        return out, (h, c)
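A note on the `states` argument of `Decoder.forward`: besides processing a whole padded label sequence at training time, it lets the prediction network be driven one token at a time, which is what greedy/beam search needs. A minimal sketch (not part of the diff; all sizes here are illustrative, not taken from the recipe):

```python
import torch

from transducer.decoder import Decoder

decoder = Decoder(
    vocab_size=500,
    embedding_dim=256,
    blank_id=0,
    sos_id=1,
    num_layers=2,
    hidden_dim=256,
)
decoder.eval()

# Start from the SOS symbol with no previous state (states=None
# makes the LSTM start from zero states).
y = torch.tensor([[decoder.sos_id]])  # shape (N=1, U=1)
out, (h, c) = decoder(y)              # out: (1, 1, hidden_dim)

# Feed the next predicted token, reusing (h, c) so the LSTM continues
# from where it left off instead of re-reading the whole prefix.
next_token = torch.tensor([[7]])      # a hypothetical predicted token
out, (h, c) = decoder(next_token, (h, c))
```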
diff --git a/egs/librispeech/ASR/transducer/encoder_interface.py b/egs/librispeech/ASR/transducer/encoder_interface.py
new file mode 100644
index 000000000..afb0bcfd1
--- /dev/null
+++ b/egs/librispeech/ASR/transducer/encoder_interface.py
@@ -0,0 +1,53 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+
+class EncoderInterface(nn.Module):
+    def __init__(self, num_features: int, output_dim: int):
+        """
+        Args:
+          num_features:
+            The dimension of the input features.
+          output_dim:
+            Output dimension of the model.
+        """
+        super().__init__()
+
+    def forward(
+        self, x: torch.Tensor, x_lens: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            A tensor of shape (batch_size, input_seq_len, num_features)
+            containing the input features.
+          x_lens:
+            A tensor of shape (batch_size,) containing the number of frames
+            in `x` before padding.
+        Returns:
+          Return a tuple containing two tensors:
+            - encoder_out, a tensor of shape (batch_size, out_seq_len,
+              output_dim) containing unnormalized probabilities, i.e.,
+              the output of a linear layer.
+            - encoder_out_lens, a tensor of shape (batch_size,) containing
+              the number of frames in `encoder_out` before padding.
+        """
+        raise NotImplementedError("Please implement it in a subclass")
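To make the contract concrete, here is a hypothetical minimal subclass (illustration only; `LinearEncoder` is not part of the diff). It satisfies the interface with a single projection and no subsampling; real recipes plug in a Conformer or LSTM encoder that also subsamples, so `out_seq_len` and `encoder_out_lens` may be smaller than the input lengths.

```python
from typing import Tuple

import torch
import torch.nn as nn

from transducer.encoder_interface import EncoderInterface


class LinearEncoder(EncoderInterface):
    def __init__(self, num_features: int, output_dim: int):
        super().__init__(num_features, output_dim)
        self.proj = nn.Linear(num_features, output_dim)

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # No subsampling here, so the lengths pass through unchanged.
        return self.proj(x), x_lens
```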
+ """ + assert encoder_out.ndim == decoder_out.ndim == 3 + assert encoder_out.size(0) == decoder_out.size(0) + assert encoder_out.size(2) == decoder_out.size(2) + + encoder_out = encoder_out.unsqueeze(2) + # Now encoder_out is (N, T, 1, C) + + decoder_out = decoder_out.unsqueeze(1) + # Now decoder_out is (N, 1, U, C) + + logit = encoder_out + decoder_out + logit = F.relu(logit) + + output = self.output_linear(logit) + + return output diff --git a/egs/librispeech/ASR/transducer/model.py b/egs/librispeech/ASR/transducer/model.py new file mode 100644 index 000000000..d51d5d4ef --- /dev/null +++ b/egs/librispeech/ASR/transducer/model.py @@ -0,0 +1,127 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Note we use `rnnt_loss` from torchaudio, which exists only in +torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0 +""" +import k2 +import torch +import torch.nn as nn +import torchaudio +import torchaudio.functional +from transducer.encoder_interface import EncoderInterface + +from icefall.utils import add_sos + +assert hasattr(torchaudio.functional, "rnnt_loss"), ( + f"Current torchaudio version: {torchaudio.__version__}\n" + "Please install a version >= 0.10.0" +) + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, C) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, C) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, C). It should contain + two attributes: `blank_id` and `sos_id`. + joiner: + It has two inputs with shapes: (N, T, C) and (N, U, C). Its + output shape is (N, T, U, C). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface) + assert hasattr(decoder, "blank_id") + assert hasattr(decoder, "sos_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + Returns: + Return the transducer loss. 
+ """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens = self.encoder(x, x_lens) + assert torch.all(x_lens > 0) + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = self.decoder.blank_id + sos_id = self.decoder.sos_id + sos_y = add_sos(y, sos_id=sos_id) + + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + decoder_out, _ = self.decoder(sos_y_padded) + + logits = self.joiner(encoder_out, decoder_out) + + # rnnt_loss requires 0 padded targets + # Note: y does not start with SOS + y_padded = y.pad(mode="constant", padding_value=0) + + loss = torchaudio.functional.rnnt_loss( + logits=logits, + targets=y_padded, + logit_lengths=x_lens, + target_lengths=y_lens, + blank=blank_id, + reduction="mean", + ) + + return loss diff --git a/egs/librispeech/ASR/transducer/test_decoder.py b/egs/librispeech/ASR/transducer/test_decoder.py new file mode 100755 index 000000000..a883eb78f --- /dev/null +++ b/egs/librispeech/ASR/transducer/test_decoder.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +To run this file, do: + + cd icefall/egs/yesno/ASR + python ./transducer/test_decoder.py +""" + +import torch +from transducer.decoder import Decoder + + +def test_decoder(): + vocab_size = 3 + blank_id = 0 + sos_id = 2 + embedding_dim = 128 + num_layers = 2 + hidden_dim = 6 + N = 3 + U = 5 + + decoder = Decoder( + vocab_size=vocab_size, + embedding_dim=embedding_dim, + blank_id=blank_id, + sos_id=sos_id, + num_layers=num_layers, + hidden_dim=hidden_dim, + embedding_dropout=0.0, + rnn_dropout=0.0, + ) + x = torch.randint(1, vocab_size, (N, U)) + rnn_out, (h, c) = decoder(x) + + assert rnn_out.shape == (N, U, hidden_dim) + assert h.shape == (num_layers, N, hidden_dim) + assert c.shape == (num_layers, N, hidden_dim) + + rnn_out, (h, c) = decoder(x, (h, c)) + assert rnn_out.shape == (N, U, hidden_dim) + assert h.shape == (num_layers, N, hidden_dim) + assert c.shape == (num_layers, N, hidden_dim) + + +def main(): + test_decoder() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer/test_joiner.py b/egs/librispeech/ASR/transducer/test_joiner.py new file mode 100755 index 000000000..2773ca319 --- /dev/null +++ b/egs/librispeech/ASR/transducer/test_joiner.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
diff --git a/egs/librispeech/ASR/transducer/test_joiner.py b/egs/librispeech/ASR/transducer/test_joiner.py
new file mode 100755
index 000000000..2773ca319
--- /dev/null
+++ b/egs/librispeech/ASR/transducer/test_joiner.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./transducer/test_joiner.py
+"""
+
+
+import torch
+from transducer.joiner import Joiner
+
+
+def test_joiner():
+    N = 2
+    T = 3
+    C = 4
+    U = 5
+
+    joiner = Joiner(C, 10)
+
+    encoder_out = torch.rand(N, T, C)
+    decoder_out = torch.rand(N, U, C)
+
+    joint = joiner(encoder_out, decoder_out)
+    assert joint.shape == (N, T, U, 10)
+
+
+def main():
+    test_joiner()
+
+
+if __name__ == "__main__":
+    main()