# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn


class Decoder(nn.Module):
    """Stateless prediction network (decoder) for an RNN-transducer.

    This class implements the stateless decoder from the following paper:

        RNN-transducer with stateless prediction network
        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419

    It removes the recurrent connection from the decoder, i.e., the
    prediction network. The output at each position is simply the
    embedding of the token at that position; no state is carried between
    positions.

    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        blank_id: int,
    ):
        """
        Args:
          vocab_size:
            Number of tokens of the modeling unit including blank.
          embedding_dim:
            Dimension of the input embedding.
          blank_id:
            The ID of the blank symbol.
        """
        super().__init__()
        # Using blank_id as padding_idx means the blank row of the table is
        # initialized to zeros and receives no gradient during training.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=blank_id,
        )
        self.blank_id = blank_id

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        """Look up the embedding of every token in ``y``.

        Args:
          y:
            A 2-D integer tensor of shape (N, U) containing token IDs,
            with blank prepended.
        Returns:
          Return a tensor of shape (N, U, embedding_dim).
        """
        embedding_out = self.embedding(y)
        return embedding_out