diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index 003b03a2e..bc4bcb3f6 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -17,6 +17,9 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch import Tensor +from typing import Optional +from subsampling import ScaledConv1d class Decoder(nn.Module): @@ -52,7 +55,7 @@ class Decoder(nn.Module): 1 means bigram; 2 means trigram. n means (n+1)-gram. """ super().__init__() - self.embedding = nn.Embedding( + self.embedding = ScaledEmbedding( num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=blank_id, @@ -62,7 +65,7 @@ class Decoder(nn.Module): assert context_size >= 1, context_size self.context_size = context_size if context_size > 1: - self.conv = nn.Conv1d( + self.conv = ScaledConv1d( in_channels=embedding_dim, out_channels=embedding_dim, kernel_size=context_size, @@ -97,3 +100,183 @@ class Decoder(nn.Module): embedding_out = self.conv(embedding_out) embedding_out = embedding_out.permute(0, 2, 1) return embedding_out + + + +class ScaledEmbedding(nn.Module): + r"""A simple lookup table that stores embeddings of a fixed dictionary and size. + + This module is often used to store word embeddings and retrieve them using indices. + The input to the module is a list of indices, and the output is the corresponding + word embeddings. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` + (initialized to zeros) whenever it encounters the index. + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. + scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. + See Notes for more details regarding sparse gradients. + + Attributes: + weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) + initialized from :math:`\mathcal{N}(0, 1)` + + Shape: + - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract + - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` + + .. note:: + Keep in mind that only a limited number of optimizers support + sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), + :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) + + .. note:: + With :attr:`padding_idx` set, the embedding vector at + :attr:`padding_idx` is initialized to all zeros. However, note that this + vector can be modified afterwards, e.g., using a customized + initialization method, and thus changing the vector used to pad the + output. The gradient for this vector from :class:`~torch.nn.Embedding` + is always zero. + + Examples:: + + >>> # an Embedding module containing 10 tensors of size 3 + >>> embedding = nn.Embedding(10, 3) + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) + >>> embedding(input) + tensor([[[-0.0251, -1.6902, 0.7172], + [-0.6431, 0.0748, 0.6969], + [ 1.4970, 1.3448, -0.9685], + [-0.3677, -2.7265, -0.1685]], + + [[ 1.4970, 1.3448, -0.9685], + [ 0.4362, -0.4004, 0.9400], + [-0.6431, 0.0748, 0.6969], + [ 0.9124, -2.3616, 1.1151]]]) + + + >>> # example with padding_idx + >>> embedding = nn.Embedding(10, 3, padding_idx=0) + >>> input = torch.LongTensor([[0,2,0,5]]) + >>> embedding(input) + tensor([[[ 0.0000, 0.0000, 0.0000], + [ 0.1535, -2.0309, 0.9315], + [ 0.0000, 0.0000, 0.0000], + [-0.1655, 0.9897, 0.0635]]]) + """ + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', + 'scale_grad_by_freq', 'sparse'] + + num_embeddings: int + embedding_dim: int + padding_idx: int + scale_grad_by_freq: bool + weight: Tensor + sparse: bool + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, _weight: Optional[Tensor] = None, + scale_speed: float = 5.0) -> None: + super(ScaledEmbedding, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.scale_grad_by_freq = scale_grad_by_freq + + self.scale_speed = scale_speed + self.scale = nn.Parameter(torch.tensor(embedding_dim**0.5).log() / scale_speed) + + if _weight is None: + self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) + self.reset_parameters() + else: + assert list(_weight.shape) == [num_embeddings, embedding_dim], \ + 'Shape of weight does not match num_embeddings and embedding_dim' + self.weight = nn.Parameter(_weight) + self.sparse = sparse + + def reset_parameters(self) -> None: + nn.init.normal_(self.weight, std=self.embedding_dim**-0.5) + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + scale = (self.scale * self.scale_speed).exp() + if input.numel() < self.num_embeddings: + return F.embedding( + input, self.weight, self.padding_idx, + None, 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, self.sparse) * scale + else: + return F.embedding( + input, self.weight * scale, self.padding_idx, + None, 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, self.sparse) + + + def extra_repr(self) -> str: + s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' + if self.padding_idx is not None: + s += ', padding_idx={padding_idx}' + if self.scale_grad_by_freq is not False: + s += ', scale_grad_by_freq={scale_grad_by_freq}' + if self.sparse is not False: + s += ', sparse=True' + return s.format(**self.__dict__) + + @classmethod + def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, + max_norm=None, norm_type=2., scale_grad_by_freq=False, + sparse=False): + r"""Creates Embedding instance from given 2-dimensional FloatTensor. + + Args: + embeddings (Tensor): FloatTensor containing weights for the Embedding. + First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``. + freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process. + Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True`` + padding_idx (int, optional): See module initialization documentation. + max_norm (float, optional): See module initialization documentation. + norm_type (float, optional): See module initialization documentation. Default ``2``. + scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``. + sparse (bool, optional): See module initialization documentation. + + Examples:: + + >>> # FloatTensor containing pretrained weights + >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]]) + >>> embedding = nn.Embedding.from_pretrained(weight) + >>> # Get embeddings for index 1 + >>> input = torch.LongTensor([1]) + >>> embedding(input) + tensor([[ 4.0000, 5.1000, 6.3000]]) + """ + assert embeddings.dim() == 2, \ + 'Embeddings parameter is expected to be 2-dimensional' + rows, cols = embeddings.shape + embedding = cls( + num_embeddings=rows, + embedding_dim=cols, + _weight=embeddings, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + embedding.weight.requires_grad = not freeze + return embedding diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py index 9fd9da4f1..8311461d3 100644 --- a/egs/librispeech/ASR/transducer_stateless/joiner.py +++ b/egs/librispeech/ASR/transducer_stateless/joiner.py @@ -16,7 +16,7 @@ import torch import torch.nn as nn - +from subsampling import ScaledLinear class Joiner(nn.Module): def __init__(self, input_dim: int, output_dim: int): @@ -24,7 +24,7 @@ class Joiner(nn.Module): self.input_dim = input_dim self.output_dim = output_dim - self.output_linear = nn.Linear(input_dim, output_dim) + self.output_linear = ScaledLinear(input_dim, output_dim) def forward( self, diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index b871efd13..c2202fe1e 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2b", + default="transducer_stateless/randcombine1_expscale3_rework2c", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved diff --git a/egs/librispeech/ASR/transducer_stateless/transformer.py b/egs/librispeech/ASR/transducer_stateless/transformer.py index e851dcc32..3fa847f4f 100644 --- a/egs/librispeech/ASR/transducer_stateless/transformer.py +++ b/egs/librispeech/ASR/transducer_stateless/transformer.py @@ -21,7 +21,7 @@ from typing import Optional, Tuple import torch import torch.nn as nn from encoder_interface import EncoderInterface -from subsampling import Conv2dSubsampling, VggSubsampling +from subsampling import Conv2dSubsampling, VggSubsampling, ScaledLinear from icefall.utils import make_pad_mask @@ -106,7 +106,7 @@ class Transformer(EncoderInterface): # TODO(fangjun): remove dropout self.encoder_output_layer = nn.Sequential( - nn.Dropout(p=dropout), nn.Linear(d_model, output_dim) + nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim) ) def forward(