diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 6446b6704..fe265eb81 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -17,12 +17,14 @@
 
 import collections
 from itertools import repeat
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 import torch.nn as nn
 from torch import Tensor, _VF
 
+import torch.backends.cudnn.rnn as rnn
+
 
 def _ntuple(n):
     def parse(x):
@@ -419,12 +421,72 @@ class ScaledLSTM(nn.LSTM):
             elif "bias" in name:
                 nn.init.constant_(self._flat_weights[idx], 0.0)
 
-    def get_flat_weights(self):
+    def _flatten_parameters(self, flat_weights) -> None:
+        """Resets parameter data pointer so that they can use faster code paths.
+
+        Right now, this works only if the module is on the GPU and cuDNN is enabled.
+        Otherwise, it's a no-op.
+
+        This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa
+        """
+        # Short-circuits if _flat_weights is only partially instantiated
+        if len(flat_weights) != len(self._flat_weights_names):
+            return
+
+        for w in flat_weights:
+            if not isinstance(w, Tensor):
+                return
+        # Short-circuits if any tensor in flat_weights is not acceptable to cuDNN
+        # or the tensors in flat_weights are of different dtypes
+
+        first_fw = flat_weights[0]
+        dtype = first_fw.dtype
+        for fw in flat_weights:
+            if (
+                not isinstance(fw.data, Tensor)
+                or not (fw.data.dtype == dtype)
+                or not fw.data.is_cuda
+                or not torch.backends.cudnn.is_acceptable(fw.data)
+            ):
+                return
+
+        # If any parameters alias, we fall back to the slower, copying code path. This is
+        # a sufficient check, because overlapping parameter buffers that don't completely
+        # alias would break the assumptions of the uniqueness check in
+        # Module.named_parameters().
+        unique_data_ptrs = set(p.data_ptr() for p in flat_weights)
+        if len(unique_data_ptrs) != len(flat_weights):
+            return
+
+        with torch.cuda.device_of(first_fw):
+
+            # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
+            # an inplace operation on self._flat_weights
+            with torch.no_grad():
+                if torch._use_cudnn_rnn_flatten_weight():
+                    num_weights = 4 if self.bias else 2
+                    if self.proj_size > 0:
+                        num_weights += 1
+                    torch._cudnn_rnn_flatten_weight(
+                        flat_weights,
+                        num_weights,
+                        self.input_size,
+                        rnn.get_cudnn_mode(self.mode),
+                        self.hidden_size,
+                        self.proj_size,
+                        self.num_layers,
+                        self.batch_first,
+                        bool(self.bidirectional),
+                    )
+
+    def _get_flat_weights(self):
+        """Get scaled weights, and resets their data pointer."""
         flat_weights = []
         for idx in range(len(self._flat_weights_names)):
             flat_weights.append(
                 self._flat_weights[idx] * self._scales[idx].exp()
             )
+        self._flatten_parameters(flat_weights)
         return flat_weights
 
     def forward(
@@ -432,7 +494,7 @@
     ):
         # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa
         # The change for calling `_VF.lstm()` is:
-        # self._flat_weights -> self.get_flat_weights()
+        # self._flat_weights -> self._get_flat_weights()
         if hx is None:
            h_zeros = torch.zeros(
                1,
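
For context, here is a minimal usage sketch, not part of the patch itself, showing how the renamed _get_flat_weights() and the new _flatten_parameters() are exercised through an ordinary forward pass. The import path and constructor arguments are assumptions: it presumes the patched scaling.py is importable (e.g. when run from egs/librispeech/ASR/pruned_transducer_stateless2/) and that ScaledLSTM accepts the usual nn.LSTM constructor keywords, since it subclasses nn.LSTM.

```python
# Hypothetical usage sketch -- not part of the patch.  Assumes the patched
# scaling.py is on PYTHONPATH and that ScaledLSTM, as an nn.LSTM subclass,
# takes the standard nn.LSTM constructor keywords.
import torch

from scaling import ScaledLSTM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Unidirectional, single-layer LSTM with bias (so num_weights per layer is 4).
lstm = ScaledLSTM(input_size=80, hidden_size=512).to(device)

# nn.LSTM defaults to batch_first=False: (seq_len, batch, input_size).
x = torch.randn(20, 8, 80, device=device)
output, (h, c) = lstm(x)

# On a CUDA device with cuDNN available, forward() now calls
# _flatten_parameters() on the scaled weights, so _VF.lstm() can take the
# fused cuDNN path; on CPU the call is a no-op and behaviour is unchanged.
print(output.shape)  # torch.Size([20, 8, 512])
```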