diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py
index 52424c0bb..781176ea9 100644
--- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import copy
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from encoder_interface import EncoderInterface
@@ -73,7 +73,7 @@ class RNN(EncoderInterface):
         # (2) embedding: num_features -> d_model
         self.encoder_embed = Conv2dSubsampling(num_features, d_model)
 
-        self.encoder_layers = num_encoder_layers
+        self.num_encoder_layers = num_encoder_layers
         self.d_model = d_model
 
         encoder_layer = RNNEncoderLayer(
@@ -119,8 +119,8 @@ class RNN(EncoderInterface):
         return x, lengths
 
     @torch.jit.export
-    def get_init_state(self, device: torch.device) -> torch.Tensor:
-        """Get model initial state."""
+    def get_init_states(self, device: torch.device) -> torch.Tensor:
+        """Get model initial states."""
         init_states = torch.zeros(
             (2, self.num_encoder_layers, self.d_model), device=device
         )
@@ -283,7 +283,7 @@ class RNNEncoderLayer(nn.Module):
 
         # lstm module
         # The required shapes of h_0 and c_0 are both (1, N, E).
-        src_lstm, new_states = self.lstm(src, states.unbind(dim=0))
+        src_lstm, new_states = self.lstm(src, (states[0], states[1]))
         new_states = torch.stack(new_states, dim=0)
         src = src + self.dropout(src_lstm)
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index d0c16cd1e..6446b6704 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -22,7 +22,6 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 from torch import Tensor, _VF
-from torch.nn.utils.rnn import PackedSequence
 
 
 def _ntuple(n):
@@ -428,84 +427,45 @@ class ScaledLSTM(nn.LSTM):
         )
         return flat_weights
 
-    def forward(self, input, hx=None):
-        # This function is copied from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py  # noqa
-        # The only change is for calling `_VF.lstm()`:
+    def forward(
+        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
+    ):
+        # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py  # noqa
+        # The change for calling `_VF.lstm()` is:
         #   self._flat_weights -> self.get_flat_weights()
-        orig_input = input
-        # xxx: isinstance check needs to be in conditional for TorchScript to compile
-        if isinstance(orig_input, PackedSequence):
-            input, batch_sizes, sorted_indices, unsorted_indices = input
-            max_batch_size = batch_sizes[0]
-            max_batch_size = int(max_batch_size)
-        else:
-            batch_sizes = None
-            max_batch_size = (
-                input.size(0) if self.batch_first else input.size(1)
-            )
-            sorted_indices = None
-            unsorted_indices = None
-
         if hx is None:
-            num_directions = 2 if self.bidirectional else 1
-            real_hidden_size = (
-                self.proj_size if self.proj_size > 0 else self.hidden_size
-            )
             h_zeros = torch.zeros(
-                self.num_layers * num_directions,
-                max_batch_size,
-                real_hidden_size,
+                1,
+                input.size(1),
+                self.hidden_size,
                 dtype=input.dtype,
                 device=input.device,
             )
             c_zeros = torch.zeros(
-                self.num_layers * num_directions,
-                max_batch_size,
+                1,
+                input.size(1),
                 self.hidden_size,
                 dtype=input.dtype,
                 device=input.device,
             )
             hx = (h_zeros, c_zeros)
-        else:
-            # Each batch of the hidden state should match the input sequence that
-            # the user believes he/she is passing in.
-            hx = self.permute_hidden(hx, sorted_indices)
 
-        self.check_forward_args(input, hx, batch_sizes)
-        if batch_sizes is None:
-            result = _VF.lstm(
-                input,
-                hx,
-                self.get_flat_weights(),
-                self.bias,
-                self.num_layers,
-                self.dropout,
-                self.training,
-                self.bidirectional,
-                self.batch_first,
-            )
-        else:
-            result = _VF.lstm(
-                input,
-                batch_sizes,
-                hx,
-                self.get_flat_weights(),
-                self.bias,
-                self.num_layers,
-                self.dropout,
-                self.training,
-                self.bidirectional,
-            )
+        self.check_forward_args(input, hx, None)
+        result = _VF.lstm(
+            input,
+            hx,
+            self.get_flat_weights(),
+            self.bias,
+            self.num_layers,
+            self.dropout,
+            self.training,
+            self.bidirectional,
+            self.batch_first,
+        )
+
         output = result[0]
         hidden = result[1:]
-        # xxx: isinstance check needs to be in conditional for TorchScript to compile
-        if isinstance(orig_input, PackedSequence):
-            output_packed = PackedSequence(
-                output, batch_sizes, sorted_indices, unsorted_indices
-            )
-            return output_packed, self.permute_hidden(hidden, unsorted_indices)
-        else:
-            return output, self.permute_hidden(hidden, unsorted_indices)
+        return output, hidden
 
 
 class ActivationBalancer(torch.nn.Module):
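
For reference, a minimal sketch of how the simplified ScaledLSTM.forward is expected to be called after this patch. The constructor arguments and tensor sizes below are illustrative assumptions (ScaledLSTM subclasses nn.LSTM, so the standard nn.LSTM arguments are assumed to apply); the sketch reflects what the new code enforces: a unidirectional, single-layer LSTM fed a plain tensor of shape (T, N, input_size), since the PackedSequence and multi-layer/bidirectional branches are removed.

    import torch
    from scaling import ScaledLSTM  # module path assumed from this diff

    # Hypothetical sizes, for illustration only.
    lstm = ScaledLSTM(input_size=512, hidden_size=512, num_layers=1)
    x = torch.randn(20, 4, 512)  # (T, N, C); batch_first=False

    # hx=None: forward() now builds (1, N, hidden_size) zero states
    # itself instead of handling PackedSequence inputs.
    y, (h, c) = lstm(x)

    # Carrying states across chunks, mirroring the
    # `self.lstm(src, (states[0], states[1]))` call in RNNEncoderLayer.
    y2, (h2, c2) = lstm(x, (h, c))

The same motivation plausibly explains the lstm.py change: `states.unbind(dim=0)` yields a List[Tensor] under TorchScript, whereas the LSTM expects a Tuple[Tensor, Tensor], so explicit indexing `(states[0], states[1])` keeps the module scriptable for streaming export.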