diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/decoder.py
deleted file mode 100644
index 3d4e69a4b..000000000
--- a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/decoder.py
+++ /dev/null
@@ -1,100 +0,0 @@
-# Copyright    2021  Xiaomi Corp.  (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class Decoder(nn.Module):
-    """This class modifies the stateless decoder from the following paper:
-
-        RNN-transducer with stateless prediction network
-        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
-
-    It removes the recurrent connection from the decoder, i.e., the
-    prediction network. Different from the above paper, it adds an
-    extra Conv1d right after the embedding layer.
-
-    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
-    """
-
-    def __init__(
-        self,
-        vocab_size: int,
-        embedding_dim: int,
-        blank_id: int,
-        context_size: int,
-    ):
-        """
-        Args:
-          vocab_size:
-            Number of tokens of the modeling unit including blank.
-          embedding_dim:
-            Dimension of the input embedding.
-          blank_id:
-            The ID of the blank symbol.
-          context_size:
-            Number of previous words to use to predict the next word.
-            1 means bigram; 2 means trigram. n means (n+1)-gram.
-        """
-        super().__init__()
-        self.embedding = nn.Embedding(
-            num_embeddings=vocab_size,
-            embedding_dim=embedding_dim,
-            padding_idx=blank_id,
-        )
-        self.blank_id = blank_id
-
-        assert context_size >= 1, context_size
-        self.context_size = context_size
-        if context_size > 1:
-            self.conv = nn.Conv1d(
-                in_channels=embedding_dim,
-                out_channels=embedding_dim,
-                kernel_size=context_size,
-                padding=0,
-                groups=embedding_dim,
-                bias=False,
-            )
-        self.output_linear = nn.Linear(embedding_dim, vocab_size)
-
-    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
-        """
-        Args:
-          y:
-            A 2-D tensor of shape (N, U) with blank prepended.
-          need_pad:
-            True to left pad the input. Should be True during training.
-            False to not pad the input. Should be False during inference.
-        Returns:
-          Return a tensor of shape (N, U, vocab_size).
-        """
-        embedding_out = self.embedding(y)
-        if self.context_size > 1:
-            embedding_out = embedding_out.permute(0, 2, 1)
-            if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
-            else:
-                # During inference time, there is no need to do extra padding
-                # as we only need one output
-                assert embedding_out.size(-1) == self.context_size
-            embedding_out = self.conv(embedding_out)
-            embedding_out = embedding_out.permute(0, 2, 1)
-        embedding_out = self.output_linear(F.relu(embedding_out))
-        return embedding_out
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/decoder.py
new file mode 120000
index 000000000..eada91097
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/decoder.py
@@ -0,0 +1 @@
+../transducer_stateless/decoder.py
\ No newline at end of file
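Note: the deleted copy is replaced by a symlink, so this recipe now shares the stateless decoder with `transducer_stateless`. Below is a minimal usage sketch, assuming the symlink target keeps the constructor and `forward(y, need_pad)` interface shown above but returns the raw `(N, U, embedding_dim)` output (it has no final `output_linear` projection); all sizes are illustrative, not the recipe's actual values:

```python
import torch

# Resolves to ../transducer_stateless/decoder.py via the new symlink.
from decoder import Decoder

decoder = Decoder(
    vocab_size=500,     # illustrative BPE vocab size
    embedding_dim=512,  # matches encoder_out_dim introduced in train.py below
    blank_id=0,
    context_size=2,     # 2 previous symbols, i.e. a trigram predictor
)

# Training: pass the full (N, U) label sequence; need_pad=True left-pads
# by context_size - 1 so the depthwise Conv1d output keeps length U.
y = torch.randint(low=1, high=500, size=(8, 20))
assert decoder(y, need_pad=True).shape == (8, 20, 512)

# Inference: feed exactly the last context_size symbols with no padding;
# the Conv1d then emits a single output frame per stream.
assert decoder(y[:, -2:], need_pad=False).shape == (8, 1, 512)
```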
- """ - embedding_out = self.embedding(y) - if self.context_size > 1: - embedding_out = embedding_out.permute(0, 2, 1) - if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) - else: - # During inference time, there is no need to do extra padding - # as we only need one output - assert embedding_out.size(-1) == self.context_size - embedding_out = self.conv(embedding_out) - embedding_out = embedding_out.permute(0, 2, 1) - embedding_out = self.output_linear(F.relu(embedding_out)) - return embedding_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/decoder.py new file mode 120000 index 000000000..eada91097 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/decoder.py @@ -0,0 +1 @@ +../transducer_stateless/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/joiner.py index 7c5a93a86..3342f8480 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/joiner.py @@ -20,11 +20,20 @@ import torch.nn.functional as F class Joiner(nn.Module): - def __init__(self, input_dim: int, inner_dim: int, output_dim: int): + def __init__(self, input_dim: int, output_dim: int): + """ + Args: + input_dim: + Input dim of the joiner. It should be equal + to the output dim of the encoder and decoder. + output_dim: + Output dim of the joiner. It should be equal + to the vocab_size. + """ super().__init__() - - self.inner_linear = nn.Linear(input_dim, inner_dim) - self.output_linear = nn.Linear(inner_dim, output_dim) + self.input_dim = input_dim + self.output_dim = output_dim + self.output_linear = nn.Linear(input_dim, output_dim) def forward( self, encoder_out: torch.Tensor, decoder_out: torch.Tensor @@ -40,11 +49,10 @@ class Joiner(nn.Module): """ assert encoder_out.ndim == decoder_out.ndim == 4 assert encoder_out.shape == decoder_out.shape + assert encoder_out.size(-1) == self.input_dim - logit = encoder_out + decoder_out + x = encoder_out + decoder_out + activations = torch.tanh(x) + logits = self.output_linear(activations) - logit = self.inner_linear(torch.tanh(logit)) - - output = self.output_linear(F.relu(logit)) - - return output + return logits diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/model.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/model.py index 2f019bcdb..ef0a9648c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/model.py @@ -46,8 +46,8 @@ class Transducer(nn.Module): is (N, U) and its output shape is (N, U, C). It should contain one attribute: `blank_id`. joiner: - It has two inputs with shapes: (N, T, C) and (N, U, C). Its - output shape is (N, T, U, C). Note that its output contains + It has two inputs with shapes: (N, T, U, C) and (N, T, U, C). Its + output shape is also (N, T, U, C). Note that its output contains unnormalized probs, i.e., not processed by log-softmax. 
""" super().__init__() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py index 25f6c3608..9bfbabf48 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py @@ -246,6 +246,7 @@ def get_params() -> AttributeDict: "log_diagnostics": False, # parameters for conformer "feature_dim": 80, + "encoder_out_dim": 512, "subsampling_factor": 4, "attention_dim": 512, "nhead": 8, @@ -267,7 +268,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: # TODO: We can add an option to switch between Conformer and Transformer encoder = Conformer( num_features=params.feature_dim, - output_dim=params.vocab_size, + output_dim=params.encoder_out_dim, subsampling_factor=params.subsampling_factor, d_model=params.attention_dim, nhead=params.nhead, @@ -279,9 +280,12 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: def get_decoder_model(params: AttributeDict) -> nn.Module: + # Note: We set the embedding_dim of the decoder to + # vocab_size so that its output can be added with + # that of the encoder decoder = Decoder( vocab_size=params.vocab_size, - embedding_dim=params.embedding_dim, + embedding_dim=params.encoder_out_dim, blank_id=params.blank_id, context_size=params.context_size, ) @@ -290,8 +294,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - input_dim=params.vocab_size, - inner_dim=params.embedding_dim, + input_dim=params.encoder_out_dim, output_dim=params.vocab_size, ) return joiner