From 6769087d702b3b8fed473e2da487772622be26c1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 18 Mar 2022 16:31:25 +0800 Subject: [PATCH] Remove scale_speed, make swish deriv more efficient. --- .../pruned_transducer_stateless2/conformer.py | 6 +- .../pruned_transducer_stateless2/decoder.py | 138 +---------- .../pruned_transducer_stateless2/scaling.py | 222 ++++++++++++++---- 3 files changed, 181 insertions(+), 185 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 245af05e3..cb4652840 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -410,7 +410,6 @@ class RelPositionMultiheadAttention(nn.Module): embed_dim: int, num_heads: int, dropout: float = 0.0, - scale_speed: float = 5.0 ) -> None: super(RelPositionMultiheadAttention, self).__init__() self.embed_dim = embed_dim @@ -430,16 +429,15 @@ class RelPositionMultiheadAttention(nn.Module): # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.scale_speed = scale_speed self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) self._reset_parameters() def _pos_bias_u(self): - return self.pos_bias_u * (self.pos_bias_u_scale * self.scale_speed).exp() + return self.pos_bias_u * self.pos_bias_u_scale.exp() def _pos_bias_v(self): - return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp() + return self.pos_bias_v * self.pos_bias_v_scale.exp() def _reset_parameters(self) -> None: nn.init.normal_(self.pos_bias_u, std=0.05) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index 7836ca999..47a519dc9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -19,7 +19,7 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor from typing import Optional -from scaling import ScaledConv1d, ScaledLinear +from scaling import ScaledConv1d, ScaledLinear, ScaledEmbedding class Decoder(nn.Module): @@ -103,139 +103,3 @@ class Decoder(nn.Module): embedding_out = embedding_out.permute(0, 2, 1) embedding_out = self.output_linear(F.relu(embedding_out)) return embedding_out - - - -class ScaledEmbedding(nn.Module): - r"""A simple lookup table that stores embeddings of a fixed dictionary and size. - - This module is often used to store word embeddings and retrieve them using indices. - The input to the module is a list of indices, and the output is the corresponding - word embeddings. - - Args: - num_embeddings (int): size of the dictionary of embeddings - embedding_dim (int): the size of each embedding vector - padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` - (initialized to zeros) whenever it encounters the index. - max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` - is renormalized to have norm :attr:`max_norm`. - norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. - scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of - the words in the mini-batch. Default ``False``. - sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. - See Notes for more details regarding sparse gradients. - - Attributes: - weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) - initialized from :math:`\mathcal{N}(0, 1)` - - Shape: - - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract - - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` - - .. note:: - Keep in mind that only a limited number of optimizers support - sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), - :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) - - .. note:: - With :attr:`padding_idx` set, the embedding vector at - :attr:`padding_idx` is initialized to all zeros. However, note that this - vector can be modified afterwards, e.g., using a customized - initialization method, and thus changing the vector used to pad the - output. The gradient for this vector from :class:`~torch.nn.Embedding` - is always zero. - - Examples:: - - >>> # an Embedding module containing 10 tensors of size 3 - >>> embedding = nn.Embedding(10, 3) - >>> # a batch of 2 samples of 4 indices each - >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) - >>> embedding(input) - tensor([[[-0.0251, -1.6902, 0.7172], - [-0.6431, 0.0748, 0.6969], - [ 1.4970, 1.3448, -0.9685], - [-0.3677, -2.7265, -0.1685]], - - [[ 1.4970, 1.3448, -0.9685], - [ 0.4362, -0.4004, 0.9400], - [-0.6431, 0.0748, 0.6969], - [ 0.9124, -2.3616, 1.1151]]]) - - - >>> # example with padding_idx - >>> embedding = nn.Embedding(10, 3, padding_idx=0) - >>> input = torch.LongTensor([[0,2,0,5]]) - >>> embedding(input) - tensor([[[ 0.0000, 0.0000, 0.0000], - [ 0.1535, -2.0309, 0.9315], - [ 0.0000, 0.0000, 0.0000], - [-0.1655, 0.9897, 0.0635]]]) - """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] - - num_embeddings: int - embedding_dim: int - padding_idx: int - scale_grad_by_freq: bool - weight: Tensor - sparse: bool - - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - scale_speed: float = 5.0) -> None: - super(ScaledEmbedding, self).__init__() - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - if padding_idx is not None: - if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' - elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' - padding_idx = self.num_embeddings + padding_idx - self.padding_idx = padding_idx - self.scale_grad_by_freq = scale_grad_by_freq - - self.scale_speed = scale_speed - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() - self.sparse = sparse - - self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) - self.reset_parameters() - - - - def reset_parameters(self) -> None: - nn.init.normal_(self.weight, std=0.05) - nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) - - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def forward(self, input: Tensor) -> Tensor: - scale = (self.scale * self.scale_speed).exp() - if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale - else: - return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) - - def extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' - if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' - if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' - if self.sparse is not False: - s += ', sparse=True' - return s.format(**self.__dict__) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index c8bc35fd1..f0e3fe148 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn from torch import Tensor -from typing import Tuple +from typing import Tuple, Optional @@ -94,31 +94,25 @@ class BasicNorm(torch.nn.Module): to indicate the connection with conventional LayerNorm. learn_eps: if true, we learn epsilon; if false, we keep it at the initial value. - eps_speed: a constant that determines how fast "eps" learns; - with Adam and variants, this should probably be >= 1, - e.g. 5.0. For SGD and variants, probably a value less than one, - like 0.1, would be suitable, to prevent instability. """ def __init__(self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. eps: float = 0.25, - learn_eps: bool = True, - eps_speed: float = 5.0): + learn_eps: bool = True) -> None: super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim - self.eps_speed = eps_speed if learn_eps: - self.eps = nn.Parameter((torch.tensor(eps).log() / eps_speed).detach()) + self.eps = nn.Parameter(torch.tensor(eps).log().detach()) else: - self.register_buffer('eps', (torch.tensor(eps).log() / eps_speed).detach()) + self.register_buffer('eps', torch.tensor(eps).log().detach()) def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + - (self.eps * self.eps_speed).exp()) ** -0.5 + self.eps.exp()) ** -0.5 return x * scales @@ -128,16 +122,13 @@ class ScaledLinear(nn.Linear): """ A modified version of nn.Linear where the parameters are scaled before use, via: - weight = self.weight * (self.weight_scale * self.scale_speed).exp() - bias = self.bias * (self.bias_scale * self.scale_speed).exp() + weight = self.weight * self.weight_scale.exp() + bias = self.bias * self.bias_scale.exp() Args: Accepts the standard args and kwargs that nn.Linear accepts e.g. in_features, out_features, bias=False. - scale_speed: a factor that affects how fast the weight_scale - and bias_scale learn; this value is suitable for Adam-type - optimizers. initial_scale: you can override this if you want to increase or decrease the initial magnitude of the module's output (affects the initialization of weight_scale and bias_scale). @@ -149,13 +140,11 @@ class ScaledLinear(nn.Linear): may be larger than optimal. """ def __init__(self, *args, - scale_speed: float = 5.0, initial_scale: float = 1.0, **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) - initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - self.scale_speed = scale_speed if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: @@ -172,14 +161,14 @@ class ScaledLinear(nn.Linear): fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) + self.weight_scale += torch.tensor(scale / std).log() def get_weight(self): - return self.weight * (self.weight_scale * self.scale_speed).exp() + return self.weight * self.weight_scale.exp() def get_bias(self): return (None if self.bias is None else - self.bias * (self.bias_scale * self.scale_speed).exp()) + self.bias * self.bias_scale.exp()) def forward(self, input: Tensor) -> Tensor: return torch.nn.functional.linear(input, self.get_weight(), @@ -187,11 +176,10 @@ class ScaledLinear(nn.Linear): class ScaledConv1d(nn.Conv1d): - def __init__(self, *args, scale_speed = 5.0, + def __init__(self, *args, initial_scale=1.0, **kwargs): super(ScaledConv1d, self).__init__(*args, **kwargs) - self.scale_speed = scale_speed - initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) @@ -208,15 +196,15 @@ class ScaledConv1d(nn.Conv1d): fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) + self.weight_scale += torch.tensor(scale / std).log() def get_weight(self): - return self.weight * (self.weight_scale * self.scale_speed).exp() + return self.weight * self.weight_scale.exp() def get_bias(self): return (None if self.bias is None else - self.bias * (self.bias_scale * self.scale_speed).exp()) + self.bias * self.bias_scale.exp()) def forward(self, input: Tensor) -> Tensor: F = torch.nn.functional @@ -230,10 +218,9 @@ class ScaledConv1d(nn.Conv1d): class ScaledConv2d(nn.Conv2d): - def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs): + def __init__(self, *args, initial_scale=1.0, **kwargs): super(ScaledConv2d, self).__init__(*args, **kwargs) - self.scale_speed = scale_speed - initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) @@ -250,15 +237,15 @@ class ScaledConv2d(nn.Conv2d): fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) + self.weight_scale += torch.tensor(scale / std).log() def get_weight(self): - return self.weight * (self.weight_scale * self.scale_speed).exp() + return self.weight * self.weight_scale.exp() def get_bias(self): return (None if self.bias is None else - self.bias * (self.bias_scale * self.scale_speed).exp()) + self.bias * self.bias_scale.exp()) def _conv_forward(self, input, weight): F = torch.nn.functional @@ -323,6 +310,16 @@ class ActivationBalancer(torch.nn.Module): self.max_factor, self.min_abs, self.max_abs) +# deriv of double_swish: +# double_swish(x) = x * torch.sigmoid(x-1) [this is a definition, originally +# motivated by its similarity to swish(swish(x), +# where swish(x) = x *sigmoid(x)]. +# double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) +# double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). +# Now, s'(x) = s(x) * (1-s(x)). +# double_swish'(x) = x * s'(x) + s(x). +# = x * s(x) * (1-s(x)) + s(x). +# = double_swish(x) * (1-s(x)) + s(x) def _double_swish(x: Tensor) -> Tensor: # double-swish, implemented/approximated as offset-swish @@ -331,18 +328,16 @@ def _double_swish(x: Tensor) -> Tensor: class DoubleSwishFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor) -> Tensor: - ctx.save_for_backward(x.detach()) - return _double_swish(x) + x = x.detach() + s = torch.sigmoid(x - 1.0) + y = x * s + ctx.save_for_backward(s, y) + return y @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - # TODO: can make this more efficient. - x, = ctx.saved_tensors - x.requires_grad = True - with torch.enable_grad(): - y = _double_swish(x) - y.backward(gradient=y_grad) - return x.grad + s, y = ctx.saved_tensors + return (y * (1-s) + s) * y_grad class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: @@ -353,6 +348,140 @@ class DoubleSwish(torch.nn.Module): + +class ScaledEmbedding(nn.Module): + r"""A simple lookup table that stores embeddings of a fixed dictionary and size. + + This module is often used to store word embeddings and retrieve them using indices. + The input to the module is a list of indices, and the output is the corresponding + word embeddings. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` + (initialized to zeros) whenever it encounters the index. + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. + scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. + See Notes for more details regarding sparse gradients. + + Attributes: + weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) + initialized from :math:`\mathcal{N}(0, 1)` + + Shape: + - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract + - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` + + .. note:: + Keep in mind that only a limited number of optimizers support + sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), + :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) + + .. note:: + With :attr:`padding_idx` set, the embedding vector at + :attr:`padding_idx` is initialized to all zeros. However, note that this + vector can be modified afterwards, e.g., using a customized + initialization method, and thus changing the vector used to pad the + output. The gradient for this vector from :class:`~torch.nn.Embedding` + is always zero. + + Examples:: + + >>> # an Embedding module containing 10 tensors of size 3 + >>> embedding = nn.Embedding(10, 3) + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) + >>> embedding(input) + tensor([[[-0.0251, -1.6902, 0.7172], + [-0.6431, 0.0748, 0.6969], + [ 1.4970, 1.3448, -0.9685], + [-0.3677, -2.7265, -0.1685]], + + [[ 1.4970, 1.3448, -0.9685], + [ 0.4362, -0.4004, 0.9400], + [-0.6431, 0.0748, 0.6969], + [ 0.9124, -2.3616, 1.1151]]]) + + + >>> # example with padding_idx + >>> embedding = nn.Embedding(10, 3, padding_idx=0) + >>> input = torch.LongTensor([[0,2,0,5]]) + >>> embedding(input) + tensor([[[ 0.0000, 0.0000, 0.0000], + [ 0.1535, -2.0309, 0.9315], + [ 0.0000, 0.0000, 0.0000], + [-0.1655, 0.9897, 0.0635]]]) + """ + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', + 'scale_grad_by_freq', 'sparse'] + + num_embeddings: int + embedding_dim: int + padding_idx: int + scale_grad_by_freq: bool + weight: Tensor + sparse: bool + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False) -> None: + super(ScaledEmbedding, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.scale_grad_by_freq = scale_grad_by_freq + + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.sparse = sparse + + self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) + self.reset_parameters() + + + + def reset_parameters(self) -> None: + nn.init.normal_(self.weight, std=0.05) + nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log()) + + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + scale = self.scale.exp() + if input.numel() < self.num_embeddings: + return F.embedding( + input, self.weight, self.padding_idx, + None, 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, self.sparse) * scale + else: + return F.embedding( + input, self.weight * scale, self.padding_idx, + None, 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, self.sparse) + + def extra_repr(self) -> str: + s = '{num_embeddings}, {embedding_dim}, scale={scale}' + if self.padding_idx is not None: + s += ', padding_idx={padding_idx}' + if self.scale_grad_by_freq is not False: + s += ', scale_grad_by_freq={scale_grad_by_freq}' + if self.sparse is not False: + s += ', sparse=True' + return s.format(**self.__dict__) + + def _test_activation_balancer_sign(): channel_dim = 0 probs = torch.arange(0, 1, 0.01) @@ -409,10 +538,15 @@ def _test_basic_norm(): assert y_rms > 0.5 * x_rms - +def _test_double_swish_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 0.5 + x.requires_grad = True + m = DoubleSwish() + torch.autograd.gradcheck(m, x) if __name__ == '__main__': _test_activation_balancer_sign() _test_activation_balancer_magnitude() _test_basic_norm() + _test_double_swish_deriv()