From 9165de5f57a2d14b51d430d8e2b32156ff4899c0 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Sat, 16 Jul 2022 22:47:05 +0800
Subject: [PATCH] add ScaledLSTM

---
 .../pruned_transducer_stateless2/scaling.py | 168 ++++++++++++++++--
 1 file changed, 157 insertions(+), 11 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index c190be626..54f4a53c5 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -1,4 +1,4 @@
-# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
+# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey, Zengwei Yao)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -21,7 +21,8 @@ from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
-from torch import Tensor
+from torch import Tensor, _VF
+from torch.nn.utils.rnn import PackedSequence
 
 
 def _ntuple(n):
@@ -154,7 +155,7 @@ class BasicNorm(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
         scales = (
-            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
+            torch.mean(x**2, dim=self.channel_dim, keepdim=True)
             + self.eps.exp()
         ) ** -0.5
         return x * scales
@@ -207,12 +208,12 @@ class ScaledLinear(nn.Linear):
 
     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
-        a = (3 ** 0.5) * std
+        a = (3**0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
-        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
+        scale = fan_in**-0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
             self.weight_scale += torch.tensor(scale / std).log()
 
@@ -256,12 +257,12 @@ class ScaledConv1d(nn.Conv1d):
 
     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
-        a = (3 ** 0.5) * std
+        a = (3**0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
-        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
+        scale = fan_in**-0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
             self.weight_scale += torch.tensor(scale / std).log()
 
@@ -325,12 +326,12 @@ class ScaledConv2d(nn.Conv2d):
 
     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
-        a = (3 ** 0.5) * std
+        a = (3**0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
-        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
+        scale = fan_in**-0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
             self.weight_scale += torch.tensor(scale / std).log()
 
@@ -376,6 +377,137 @@ class ScaledConv2d(nn.Conv2d):
         return self._conv_forward(input, self.get_weight())
 
 
+class ScaledLSTM(nn.LSTM):
+    # See docs for ScaledLinear.
+    # This class implements a single-layer LSTM with a scaling mechanism, using `torch._VF.lstm`.
+    # Please refer to https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py
+    def __init__(
+        self,
+        *args,
+        initial_scale: float = 1.0,
+        initial_speed: float = 1.0,
+        **kwargs
+    ):
+        # Hardcode num_layers=1 and proj_size=0 here
+        super(ScaledLSTM, self).__init__(
+            *args, num_layers=1, proj_size=0, **kwargs
+        )
+        initial_scale = torch.tensor(initial_scale).log()
+        self._scales_names = []
+        self._scales = []
+        for name in self._flat_weights_names:
+            scale_name = name + "_scale"
+            self._scales_names.append(scale_name)
+            param = nn.Parameter(initial_scale.clone().detach())
+            setattr(self, scale_name, param)
+            self._scales.append(param)
+
+        self._reset_parameters(
+            initial_speed
+        )  # Overrides the reset_parameters in the base class
+
+    def _reset_parameters(self, initial_speed: float):
+        std = 0.1 / initial_speed
+        a = (3**0.5) * std
+        fan_in = self.input_size
+        scale = fan_in**-0.5
+        v = scale / std
+        for idx, name in enumerate(self._flat_weights_names):
+            if "weight" in name:
+                nn.init.uniform_(self._flat_weights[idx], -a, a)
+                with torch.no_grad():
+                    self._scales[idx] += torch.tensor(v).log()
+            elif "bias" in name:
+                nn.init.constant_(self._flat_weights[idx], 0.0)
+
+    def get_flat_weights(self):
+        flat_weights = []
+        for idx in range(len(self._flat_weights_names)):
+            flat_weights.append(
+                self._flat_weights[idx] * self._scales[idx].exp()
+            )
+        return flat_weights
+
+    def forward(self, input, hx=None):
+        # This function is copied from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py  # noqa
+        # The only change is in the call to `_VF.lstm()`:
+        #   self._flat_weights -> self.get_flat_weights()
+        orig_input = input
+        # xxx: isinstance check needs to be in conditional for TorchScript to compile
+        if isinstance(orig_input, PackedSequence):
+            input, batch_sizes, sorted_indices, unsorted_indices = input
+            max_batch_size = batch_sizes[0]
+            max_batch_size = int(max_batch_size)
+        else:
+            batch_sizes = None
+            max_batch_size = (
+                input.size(0) if self.batch_first else input.size(1)
+            )
+            sorted_indices = None
+            unsorted_indices = None
+
+        if hx is None:
+            num_directions = 2 if self.bidirectional else 1
+            real_hidden_size = (
+                self.proj_size if self.proj_size > 0 else self.hidden_size
+            )
+            h_zeros = torch.zeros(
+                self.num_layers * num_directions,
+                max_batch_size,
+                real_hidden_size,
+                dtype=input.dtype,
+                device=input.device,
+            )
+            c_zeros = torch.zeros(
+                self.num_layers * num_directions,
+                max_batch_size,
+                self.hidden_size,
+                dtype=input.dtype,
+                device=input.device,
+            )
+            hx = (h_zeros, c_zeros)
+        else:
+            # Each batch of the hidden state should match the input sequence that
+            # the user believes he/she is passing in.
+            hx = self.permute_hidden(hx, sorted_indices)
+
+        self.check_forward_args(input, hx, batch_sizes)
+        if batch_sizes is None:
+            result = _VF.lstm(
+                input,
+                hx,
+                self.get_flat_weights(),
+                self.bias,
+                self.num_layers,
+                self.dropout,
+                self.training,
+                self.bidirectional,
+                self.batch_first,
+            )
+        else:
+            result = _VF.lstm(
+                input,
+                batch_sizes,
+                hx,
+                self.get_flat_weights(),
+                self.bias,
+                self.num_layers,
+                self.dropout,
+                self.training,
+                self.bidirectional,
+            )
+        output = result[0]
+        hidden = result[1:]
+        # xxx: isinstance check needs to be in conditional for TorchScript to compile
+        if isinstance(orig_input, PackedSequence):
+            output_packed = PackedSequence(
+                output, batch_sizes, sorted_indices, unsorted_indices
+            )
+            return output_packed, self.permute_hidden(hidden, unsorted_indices)
+        else:
+            return output, self.permute_hidden(hidden, unsorted_indices)
+
+
 class ActivationBalancer(torch.nn.Module):
     """
     Modifies the backpropped derivatives of a function to try to encourage, for
@@ -711,8 +843,8 @@ def _test_basic_norm():
 
     y = m(x)
     assert y.shape == x.shape
-    x_rms = (x ** 2).mean().sqrt()
-    y_rms = (y ** 2).mean().sqrt()
+    x_rms = (x**2).mean().sqrt()
+    y_rms = (y**2).mean().sqrt()
     print("x rms = ", x_rms)
     print("y rms = ", y_rms)
    assert y_rms < x_rms
@@ -726,8 +858,22 @@ def _test_double_swish_deriv():
     torch.autograd.gradcheck(m, x)
 
 
+def _test_scaled_lstm():
+    N, L = 2, 30
+    dim_in, dim_hidden = 10, 20
+    m = ScaledLSTM(input_size=dim_in, hidden_size=dim_hidden, bias=True)
+    x = torch.randn(L, N, dim_in)
+    h0 = torch.randn(1, N, dim_hidden)
+    c0 = torch.randn(1, N, dim_hidden)
+    y, (h, c) = m(x, (h0, c0))
+    assert y.shape == (L, N, dim_hidden)
+    assert h.shape == (1, N, dim_hidden)
+    assert c.shape == (1, N, dim_hidden)
+
+
 if __name__ == "__main__":
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()
     _test_basic_norm()
     _test_double_swish_deriv()
+    _test_scaled_lstm()
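
Usage note (editor's illustration, not part of the patch): the sketch below
shows how the new ScaledLSTM could be exercised, and how the learned
log-scales relate to the weights actually handed to `_VF.lstm`. It assumes
scaling.py is importable as `scaling`; the shape constants are arbitrary.

import torch

from scaling import ScaledLSTM

# num_layers=1 and proj_size=0 are hardcoded by ScaledLSTM.__init__,
# so this is a single-layer, unidirectional LSTM.
lstm = ScaledLSTM(input_size=80, hidden_size=512, bias=True)

# Input is (seq_len, batch, input_size); batch_first defaults to False.
x = torch.randn(100, 8, 80)

# With hx=None, forward() zero-initializes (h0, c0) internally.
output, (h, c) = lstm(x)
assert output.shape == (100, 8, 512)
assert h.shape == (1, 8, 512) and c.shape == (1, 8, 512)

# get_flat_weights() multiplies each flat weight by exp(scale) before
# calling _VF.lstm, so the effective input-to-hidden weight is
# weight_ih_l0 * weight_ih_l0_scale.exp(); the "<name>_scale" parameters
# were registered by __init__, one per entry in _flat_weights_names.
assert torch.equal(
    lstm.get_flat_weights()[0],
    lstm._flat_weights[0] * lstm.weight_ih_l0_scale.exp(),
)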