# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Decoder(nn.Module):
    """This class modifies the stateless decoder from the following paper:

        RNN-transducer with stateless prediction network
        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419

    It removes the recurrent connection from the decoder, i.e., the
    prediction network. Unlike the above paper, it adds an extra Conv1d
    right after the embedding layer.

    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        blank_id: int,
        unk_id: int,
        context_size: int,
    ):
        """
        Args:
          vocab_size:
            Number of tokens of the modeling unit, including blank.
          embedding_dim:
            Dimension of the input embedding.
          blank_id:
            The ID of the blank symbol.
          unk_id:
            The ID of the unk symbol.
          context_size:
            Number of previous words used to predict the next word.
            1 means bigram; 2 means trigram; n means (n+1)-gram.
        """
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=blank_id,
        )
        self.blank_id = blank_id
        self.unk_id = unk_id

        assert context_size >= 1, context_size
        self.context_size = context_size
        self.vocab_size = vocab_size
        if context_size > 1:
            # A depthwise 1-D convolution (groups == channels) that mixes the
            # embeddings of the previous `context_size` tokens per channel.
            self.conv = nn.Conv1d(
                in_channels=embedding_dim,
                out_channels=embedding_dim,
                kernel_size=context_size,
                padding=0,
                groups=embedding_dim,
                bias=False,
            )
        self.output_linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
        """
        Args:
          y:
            A 2-D tensor of shape (N, U) with blank prepended.
          need_pad:
            True to left-pad the input. Should be True during training.
            False to skip padding. Should be False during inference.
        Returns:
          Return a tensor of shape (N, U, vocab_size).
        """
        embedding_out = self.embedding(y)
        if self.context_size > 1:
            embedding_out = embedding_out.permute(0, 2, 1)
            if need_pad is True:
                # Left-pad the time axis with context_size - 1 zeros so the
                # causal Conv1d produces one output per input position.
                embedding_out = F.pad(
                    embedding_out, pad=(self.context_size - 1, 0)
                )
            else:
                # During inference there is no need for extra padding
                # as we only need one output.
                assert embedding_out.size(-1) == self.context_size
            embedding_out = self.conv(embedding_out)
            embedding_out = embedding_out.permute(0, 2, 1)
        embedding_out = self.output_linear(F.relu(embedding_out))
        return embedding_out
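

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The parameter values below are
# assumptions for demonstration, not taken from any recipe in this repository.
#
# During training, need_pad=True left-pads the embedding sequence so the
# causal Conv1d yields one output per label position. During greedy search,
# the caller passes exactly the last `context_size` tokens and sets
# need_pad=False, so a single prediction is produced per utterance.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    decoder = Decoder(
        vocab_size=500,      # assumed BPE vocabulary size
        embedding_dim=256,   # assumed embedding dimension
        blank_id=0,
        unk_id=2,
        context_size=2,      # trigram-like context
    )

    # Training-style call: whole label sequences with blank prepended.
    y = torch.randint(low=0, high=500, size=(4, 10))  # (N, U)
    out = decoder(y, need_pad=True)
    assert out.shape == (4, 10, 500)  # (N, U, vocab_size)

    # Inference-style call: only the last `context_size` tokens are fed.
    y_step = torch.randint(low=0, high=500, size=(4, 2))  # (N, context_size)
    out_step = decoder(y_step, need_pad=False)
    assert out_step.shape == (4, 1, 500)  # one prediction per utterance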