# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from subsampling import ScaledConv1d
from torch import Tensor


class Decoder(nn.Module):
    """This class modifies the stateless decoder from the following paper:

        RNN-transducer with stateless prediction network
        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419

    It removes the recurrent connection from the decoder, i.e., the
    prediction network. Different from the above paper, it adds an extra
    Conv1d right after the embedding layer.

    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        blank_id: int,
        context_size: int,
    ):
        """
        Args:
          vocab_size:
            Number of tokens of the modeling unit including blank.
          embedding_dim:
            Dimension of the input embedding.
          blank_id:
            The ID of the blank symbol. It is also passed as the padding
            index of the embedding table, so its embedding vector is
            initialized to zeros.
          context_size:
            Number of previous words to use to predict the next word.
            1 means bigram; 2 means trigram. n means (n+1)-gram.
        """
        super().__init__()
        self.embedding = ScaledEmbedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=blank_id,
        )
        self.blank_id = blank_id

        assert context_size >= 1, context_size
        self.context_size = context_size
        if context_size > 1:
            # groups == in_channels == out_channels: each embedding channel
            # is convolved independently over the last `context_size` tokens
            # (assuming ScaledConv1d follows nn.Conv1d semantics).
            self.conv = ScaledConv1d(
                in_channels=embedding_dim,
                out_channels=embedding_dim,
                kernel_size=context_size,
                padding=0,
                groups=embedding_dim,
                bias=False,
            )

    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
        """
        Args:
          y:
            A 2-D tensor of shape (N, U) containing token IDs.
          need_pad:
            True to left pad the input. Should be True during training.
            False to not pad the input. Should be False during inference.
        Returns:
          Return a tensor of shape (N, U, embedding_dim). When
          ``context_size > 1`` and ``need_pad`` is False, U must equal
          ``context_size`` and (assuming ScaledConv1d behaves like
          nn.Conv1d with stride 1) the output has shape (N, 1, embedding_dim).
        """
        y = y.to(torch.int64)
        embedding_out = self.embedding(y)
        if self.context_size > 1:
            # (N, U, C) -> (N, C, U): the convolution runs over the last dim.
            embedding_out = embedding_out.permute(0, 2, 1)
            if need_pad is True:
                # Left-pad with (context_size - 1) zero frames so every
                # position sees exactly `context_size` tokens, only from
                # the past, and the output length stays U.
                embedding_out = F.pad(
                    embedding_out, pad=(self.context_size - 1, 0)
                )
            else:
                # During inference time, there is no need to do extra padding
                # as we only need one output
                assert embedding_out.size(-1) == self.context_size
            embedding_out = self.conv(embedding_out)
            # (N, C, U') -> (N, U', C)
            embedding_out = embedding_out.permute(0, 2, 1)
        return embedding_out


class ScaledEmbedding(nn.Module):
    r"""A lookup table of embeddings whose output is multiplied by a learned,
    log-parameterized scalar.

    The effective embedding of index ``i`` is
    ``weight[i] * exp(scale * scale_speed)``. At initialization ``weight`` is
    drawn from :math:`\mathcal{N}(0, 0.05^2)` and ``scale`` is set so that
    ``exp(scale * scale_speed) == 1 / 0.05``; hence the initial effective
    embeddings have unit standard deviation.

    The input to the module is a tensor of indices, and the output is the
    corresponding scaled embeddings.

    Args:
        num_embeddings (int): size of the dictionary of embeddings
        embedding_dim (int): the size of each embedding vector
        padding_idx (int, optional): If given, pads the output with the
            embedding vector at :attr:`padding_idx` (initialized to zeros)
            whenever it encounters the index. A negative value is counted
            from the end of the table and normalized to a non-negative index.
        scale_grad_by_freq (boolean, optional): If given, this will scale
            gradients by the inverse of frequency of the words in the
            mini-batch. Default ``False``.
        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight`
            matrix will be a sparse tensor. See Notes for more details
            regarding sparse gradients.
        scale_speed (float, optional): factor by which the :attr:`scale`
            parameter is multiplied before exponentiation. Default ``5.0``.

    Attributes:
        weight (Tensor): the learnable weights of the module of shape
            (num_embeddings, embedding_dim), initialized from
            :math:`\mathcal{N}(0, 0.05^2)`
        scale (Tensor): learnable 0-dim tensor; the output is multiplied by
            ``exp(scale * scale_speed)``

    Shape:
        - Input: :math:`(*)`, LongTensor of arbitrary shape containing the
          indices to extract
        - Output: :math:`(*, H)`, where `*` is the input shape and
          :math:`H=\text{embedding\_dim}`

    .. note::
        Keep in mind that only a limited number of optimizers support
        sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and
        `CPU`), :class:`optim.SparseAdam` (`CUDA` and `CPU`) and
        :class:`optim.Adagrad` (`CPU`)

    .. note::
        With :attr:`padding_idx` set, the embedding vector at
        :attr:`padding_idx` is initialized to all zeros. However, note that
        this vector can be modified afterwards, e.g., using a customized
        initialization method, and thus changing the vector used to pad the
        output. The gradient for this vector from
        :func:`torch.nn.functional.embedding` is always zero.

    Examples::

        >>> # a ScaledEmbedding module containing 10 vectors of size 3
        >>> embedding = ScaledEmbedding(10, 3)
        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
        >>> embedding(input).shape
        torch.Size([2, 4, 3])
        >>> # example with padding_idx: index 0 maps to an all-zero vector
        >>> embedding = ScaledEmbedding(10, 3, padding_idx=0)
        >>> out = embedding(torch.LongTensor([[0, 2, 0, 5]]))
        >>> bool((out[0, 0] == 0).all())
        True
    """

    __constants__ = [
        "num_embeddings",
        "embedding_dim",
        "padding_idx",
        "scale_grad_by_freq",
        "sparse",
    ]

    num_embeddings: int
    embedding_dim: int
    padding_idx: Optional[int]
    scale_grad_by_freq: bool
    weight: Tensor
    sparse: bool

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        scale_speed: float = 5.0,
    ) -> None:
        super(ScaledEmbedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            if padding_idx > 0:
                assert (
                    padding_idx < self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
            elif padding_idx < 0:
                assert (
                    padding_idx >= -self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
                # Normalize a negative index to its non-negative equivalent.
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.scale_grad_by_freq = scale_grad_by_freq

        self.scale_speed = scale_speed
        # Registered before `weight` (keeps parameter order stable);
        # its real initial value is set in reset_parameters().
        self.scale = nn.Parameter(torch.zeros(()))
        self.sparse = sparse

        self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize ``weight`` and ``scale``; zero the padding row."""
        nn.init.normal_(self.weight, std=0.05)
        # Choose `scale` so that exp(scale * scale_speed) == 1 / 0.05, which
        # cancels the small std of `weight`: initial outputs have std 1.
        nn.init.constant_(
            self.scale, torch.tensor(1.0 / 0.05).log() / self.scale_speed
        )

        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        """Look up ``input`` and multiply by ``exp(scale * scale_speed)``.

        Args:
            input: integer tensor of indices, of arbitrary shape ``(*)``.
        Returns:
            Tensor of shape ``(*, embedding_dim)``.
        """
        scale = (self.scale * self.scale_speed).exp()
        if input.numel() < self.num_embeddings:
            # Fewer lookups than table rows: scaling the gathered output
            # touches fewer elements than scaling the whole table.
            return (
                F.embedding(
                    input,
                    self.weight,
                    padding_idx=self.padding_idx,
                    max_norm=None,  # no renormalization
                    norm_type=2.0,  # unused because max_norm is None
                    scale_grad_by_freq=self.scale_grad_by_freq,
                    sparse=self.sparse,
                )
                * scale
            )
        else:
            # Otherwise scale the table once, then gather from it.
            return F.embedding(
                input,
                self.weight * scale,
                padding_idx=self.padding_idx,
                max_norm=None,  # no renormalization
                norm_type=2.0,  # unused because max_norm is None
                scale_grad_by_freq=self.scale_grad_by_freq,
                sparse=self.sparse,
            )

    def extra_repr(self) -> str:
        s = (
            "{num_embeddings}, {embedding_dim}, scale_speed={scale_speed},"
            " scale={scale}"
        )
        if self.padding_idx is not None:
            s += ", padding_idx={padding_idx}"
        if self.scale_grad_by_freq is not False:
            s += ", scale_grad_by_freq={scale_grad_by_freq}"
        if self.sparse is not False:
            s += ", sparse=True"
        # `scale` is an nn.Parameter, which nn.Module keeps in
        # self._parameters rather than self.__dict__, so it must be passed
        # explicitly; formatting with self.__dict__ alone raises KeyError.
        return s.format(**{**self.__dict__, "scale": self.scale.item()})