diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py
index 45409ccea..1d6fda0b4 100644
--- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py
@@ -38,9 +38,11 @@ class RNN(EncoderInterface):
       subsampling_factor (int):
         Subsampling factor of encoder (convolution layers before lstm layers) (default=4).  # noqa
       d_model (int):
-        Hidden dimension for lstm layers, also output dimension (default=512).
+        Output dimension (default=512).
       dim_feedforward (int):
         Feedforward dimension (default=2048).
+      rnn_hidden_size (int):
+        Hidden dimension for lstm layers (default=1024).
       num_encoder_layers (int):
         Number of encoder layers (default=12).
       dropout (float):
@@ -58,6 +60,7 @@ class RNN(EncoderInterface):
         subsampling_factor: int = 4,
         d_model: int = 512,
         dim_feedforward: int = 2048,
+        rnn_hidden_size: int = 1024,
         num_encoder_layers: int = 12,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
@@ -79,9 +82,14 @@ class RNN(EncoderInterface):
 
         self.num_encoder_layers = num_encoder_layers
         self.d_model = d_model
+        self.rnn_hidden_size = rnn_hidden_size
 
         encoder_layer = RNNEncoderLayer(
-            d_model, dim_feedforward, dropout, layer_dropout
+            d_model=d_model,
+            dim_feedforward=dim_feedforward,
+            rnn_hidden_size=rnn_hidden_size,
+            dropout=dropout,
+            layer_dropout=layer_dropout,
         )
         self.encoder = RNNEncoder(
             encoder_layer,
@@ -135,17 +143,26 @@ class RNN(EncoderInterface):
         return x, lengths
 
     @torch.jit.export
-    def get_init_states(self, device: torch.device) -> torch.Tensor:
+    def get_init_states(
+        self, device: torch.device
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Get model initial states."""
-        init_states = torch.zeros(
-            (2, self.num_encoder_layers, self.d_model), device=device
+        # for rnn hidden states
+        hidden_states = torch.zeros(
+            (self.num_encoder_layers, self.d_model), device=device
         )
-        return init_states
+        cell_states = torch.zeros(
+            (self.num_encoder_layers, self.rnn_hidden_size), device=device
+        )
+        return (hidden_states, cell_states)
 
     @torch.jit.export
     def infer(
-        self, x: torch.Tensor, x_lens: torch.Tensor, states: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        states: Tuple[torch.Tensor, torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Args:
           x:
@@ -155,9 +172,11 @@ class RNN(EncoderInterface):
             A tensor of shape (N,), containing the number of frames in
             `x` before padding.
           states:
-            Its shape is (2, num_encoder_layers, N, E).
-            states[0] and states[1] are cached hidden states and cell states for
-            all layers, respectively.
+            It is a tuple of 2 tensors.
+            states[0] is the hidden states of all layers,
+              with shape of (num_layers, N, d_model);
+            states[1] is the cell states of all layers,
+              with shape of (num_layers, N, rnn_hidden_size).
 
         Returns:
           A tuple of 3 tensors:
@@ -165,15 +184,22 @@ class RNN(EncoderInterface):
             sequence lengths.
           - lengths: a tensor of shape (batch_size,) containing the number of
             frames in `embeddings` before padding.
-          - updated states, with shape of (2, num_encoder_layers, N, E).
+          - updated states, whose shape is the same as the input states.
""" assert not self.training - assert states.shape == ( - 2, + assert len(states) == 2 + # for hidden state + assert states[0].shape == ( self.num_encoder_layers, x.size(0), self.d_model, - ), states.shape + ) + # for cell state + assert states[1].shape == ( + self.num_encoder_layers, + x.size(0), + self.rnn_hidden_size, + ) # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning # @@ -201,6 +227,8 @@ class RNNEncoderLayer(nn.Module): The number of expected features in the input (required). dim_feedforward: The dimension of feedforward network model (default=2048). + rnn_hidden_size: + The hidden dimension of rnn layer. dropout: The dropout value (default=0.1). layer_dropout: @@ -211,15 +239,22 @@ class RNNEncoderLayer(nn.Module): self, d_model: int, dim_feedforward: int, + rnn_hidden_size: int, dropout: float = 0.1, layer_dropout: float = 0.075, ) -> None: super(RNNEncoderLayer, self).__init__() self.layer_dropout = layer_dropout self.d_model = d_model + self.rnn_hidden_size = rnn_hidden_size + assert rnn_hidden_size >= d_model self.lstm = ScaledLSTM( - input_size=d_model, hidden_size=d_model, dropout=0.0 + input_size=d_model, + hidden_size=rnn_hidden_size, + proj_size=d_model if rnn_hidden_size > d_model else 0, + num_layers=1, + dropout=0.0, ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -279,28 +314,30 @@ class RNNEncoderLayer(nn.Module): @torch.jit.export def infer( - self, src: torch.Tensor, states: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + self, src: torch.Tensor, states: Tuple[torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Pass the input through the encoder layer. Args: src: The sequence to the encoder layer (required). - Its shape is (S, N, E), where S is the sequence length, - N is the batch size, and E is the feature number. + Its shape is (S, N, d_model), where S is the sequence length, + N is the batch size. states: - Its shape is (2, 1, N, E). - states[0] and states[1] are cached hidden state and cell state, - respectively. + It is a tuple of 2 tensors. + states[0] is the hidden state, with shape of (1, N, d_model); + states[1] is the cell state, with shape of (1, N, rnn_hidden_size). """ assert not self.training - assert states.shape == (2, 1, src.size(1), src.size(2)) + assert len(states) == 2 + # for hidden state + assert states[0].shape == (1, src.size(1), self.d_model) + # for cell state + assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) # lstm module - # The required shapes of h_0 and c_0 are both (1, N, E). - src_lstm, new_states = self.lstm(src, (states[0], states[1])) - new_states = torch.stack(new_states, dim=0) + src_lstm, new_states = self.lstm(src, states) src = src + self.dropout(src_lstm) # feed forward module @@ -333,6 +370,8 @@ class RNNEncoder(nn.Module): [copy.deepcopy(encoder_layer) for i in range(num_layers)] ) self.num_layers = num_layers + self.d_model = encoder_layer.d_model + self.rnn_hidden_size = encoder_layer.rnn_hidden_size self.use_random_combiner = False if aux_layers is not None: @@ -377,34 +416,55 @@ class RNNEncoder(nn.Module): @torch.jit.export def infer( - self, src: torch.Tensor, states: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + self, src: torch.Tensor, states: Tuple[torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Pass the input through the encoder layer. Args: src: The sequence to the encoder layer (required). 
-            Its shape is (S, N, E), where S is the sequence length,
-            N is the batch size, and E is the feature number.
+            Its shape is (S, N, d_model), where S is the sequence length,
+            N is the batch size.
          states:
-            Its shape is (2, num_layers, N, E).
-            states[0] and states[1] are cached hidden states and cell states for
-            all layers, respectively.
+            It is a tuple of 2 tensors.
+            states[0] is the hidden states of all layers,
+              with shape of (num_layers, N, d_model);
+            states[1] is the cell states of all layers,
+              with shape of (num_layers, N, rnn_hidden_size).
        """
        assert not self.training
-        assert states.shape == (2, self.num_layers, src.size(1), src.size(2))
+        assert len(states) == 2
+        # for hidden state
+        assert states[0].shape == (self.num_layers, src.size(1), self.d_model)
+        # for cell state
+        assert states[1].shape == (
+            self.num_layers,
+            src.size(1),
+            self.rnn_hidden_size,
+        )
 
-        new_states_list = []
        output = src
+        new_hidden_states = []
+        new_cell_states = []
 
        for layer_index, mod in enumerate(self.layers):
-            # new_states: (2, 1, N, E)
-            output, new_states = mod.infer(
-                output, states[:, layer_index : layer_index + 1, :, :]
+            layer_states = (
+                states[0][
+                    layer_index : layer_index + 1, :, :
+                ],  # h: (1, N, d_model)
+                states[1][
+                    layer_index : layer_index + 1, :, :
+                ],  # c: (1, N, rnn_hidden_size)
             )
-            new_states_list.append(new_states)
+            output, (h, c) = mod.infer(output, layer_states)
+            new_hidden_states.append(h)
+            new_cell_states.append(c)
 
-        return output, torch.cat(new_states_list, dim=1)
+        new_states = (
+            torch.cat(new_hidden_states, dim=0),
+            torch.cat(new_cell_states, dim=0),
+        )
+        return output, new_states
 
 
 class Conv2dSubsampling(nn.Module):
@@ -740,8 +800,14 @@ def _test_random_combine_main():
 
 
 if __name__ == "__main__":
-    feature_dim = 50
-    m = RNN(num_features=feature_dim, d_model=128)
+    feature_dim = 80
+    m = RNN(
+        num_features=feature_dim,
+        d_model=512,
+        rnn_hidden_size=1024,
+        dim_feedforward=2048,
+        num_encoder_layers=12,
+    )
     batch_size = 5
     seq_len = 20
     # Just make sure the forward pass runs.
@@ -750,5 +816,7 @@ if __name__ == "__main__":
         torch.full((batch_size,), seq_len, dtype=torch.int64),
         warmup=0.5,
     )
+    num_param = sum([p.numel() for p in m.parameters()])
+    print(f"Number of model parameters: {num_param}")
 
     _test_random_combine_main()
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py
index 0826c72e9..89bd406b1 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py
@@ -90,16 +90,30 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
         type=int,
-        default=20,
+        default=12,
         help="Number of RNN encoder layers..",
     )
 
+    parser.add_argument(
+        "--encoder-dim",
+        type=int,
+        default=512,
+        help="Encoder output dimension.",
+    )
+
+    parser.add_argument(
+        "--rnn-hidden-size",
+        type=int,
+        default=1024,
+        help="Hidden dim for LSTM layers.",
+    )
+
     parser.add_argument(
         "--aux-layer-period",
         type=int,
         default=3,
         help="""Peroid of auxiliary layers used for randomly combined during training.
-        If not larger than 0, will not use the random combiner.
+        If not larger than 0 (e.g., -1), will not use the random combiner.
         """,
     )
 
@@ -340,8 +354,6 @@ def get_params() -> AttributeDict:
 
         - subsampling_factor:  The subsampling factor for the model.
 
-        - encoder_dim: Hidden dim for multi-head attention model.
-
         - num_decoder_layers: Number of decoder layer of transformer decoder.
 
         - warm_step: The warm_step for Noam optimizer.
 
@@ -359,7 +371,6 @@ def get_params() -> AttributeDict:
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
-            "encoder_dim": 512,
             "dim_feedforward": 2048,
             # parameters for decoder
             "decoder_dim": 512,
@@ -380,6 +391,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         num_features=params.feature_dim,
         subsampling_factor=params.subsampling_factor,
         d_model=params.encoder_dim,
+        rnn_hidden_size=params.rnn_hidden_size,
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         aux_layer_period=params.aux_layer_period,
@@ -837,7 +849,7 @@ def run(rank, world_size, args):
     params = get_params()
     params.update(vars(args))
     if params.full_libri is False:
-        params.valid_interval = 1600
+        params.valid_interval = 800
 
     fix_random_seed(params.seed)
     if world_size > 1:
@@ -903,6 +915,10 @@ def run(rank, world_size, args):
         logging.info("Loading scheduler state dict")
         scheduler.load_state_dict(checkpoints["scheduler"])
 
+    # # overwrite it
+    # scheduler.base_lrs = [params.initial_lr for _ in scheduler.base_lrs]
+    # print(scheduler.base_lrs)
+
     if params.print_diagnostics:
         diagnostic = diagnostics.attach_diagnostics(model)
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 5f0785d91..560867c3b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -379,7 +379,7 @@ class ScaledConv2d(nn.Conv2d):
 
 class ScaledLSTM(nn.LSTM):
     # See docs for ScaledLinear.
-    # This class implements single-layer LSTM with scaling mechanism, using `torch._VF.lstm`
+    # This class implements LSTM with scaling mechanism, using `torch._VF.lstm`
     # Please refer to https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py
     def __init__(
         self,
@@ -388,10 +388,8 @@
         initial_speed: float = 1.0,
         **kwargs
     ):
-        # Hardcode num_layers=1, bidirectional=False, proj_size=0 here
-        super(ScaledLSTM, self).__init__(
-            *args, num_layers=1, bidirectional=False, proj_size=0, **kwargs
-        )
+        # Hardcode bidirectional=False
+        super(ScaledLSTM, self).__init__(*args, bidirectional=False, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
         self._scales_names = []
         self._scales = []
@@ -495,14 +493,14 @@ class ScaledLSTM(nn.LSTM):
         # self._flat_weights -> self._get_flat_weights()
         if hx is None:
             h_zeros = torch.zeros(
-                1,
+                self.num_layers,
                 input.size(1),
-                self.hidden_size,
+                self.proj_size if self.proj_size > 0 else self.hidden_size,
                 dtype=input.dtype,
                 device=input.device,
             )
             c_zeros = torch.zeros(
-                1,
+                self.num_layers,
                 input.size(1),
                 self.hidden_size,
                 dtype=input.dtype,
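The core of this change is that the LSTM hidden size is decoupled from the encoder output dimension: each `RNNEncoderLayer` now builds its `ScaledLSTM` with `hidden_size=rnn_hidden_size` and `proj_size=d_model`, so the recurrent state is wider than the layer output (`proj_size` support was added to `nn.LSTM` in PyTorch 1.8). The sketch below uses plain `torch.nn.LSTM` rather than `ScaledLSTM`, with the new default sizes, only to illustrate the shape convention the rest of the diff relies on; the variable names are illustrative.

```python
# Minimal sketch with plain torch.nn.LSTM (not ScaledLSTM) showing why the cached
# hidden and cell states now have different last dimensions: with proj_size > 0 the
# hidden state is projected back to proj_size, while the cell state keeps hidden_size.
import torch
import torch.nn as nn

d_model, rnn_hidden_size = 512, 1024  # matches the new defaults in this diff
S, N = 20, 5  # sequence length, batch size

lstm = nn.LSTM(
    input_size=d_model,
    hidden_size=rnn_hidden_size,
    proj_size=d_model,  # project the hidden state back down to d_model
    num_layers=1,
)

x = torch.randn(S, N, d_model)
y, (h, c) = lstm(x)

assert y.shape == (S, N, d_model)          # output uses the projected size
assert h.shape == (1, N, d_model)          # hidden state: (num_layers, N, proj_size)
assert c.shape == (1, N, rnn_hidden_size)  # cell state:   (num_layers, N, hidden_size)
```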
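The streaming state is accordingly split into a pair of tensors instead of one stacked tensor: `states[0]` holds the hidden states of all layers with shape `(num_layers, N, d_model)` and `states[1]` holds the cell states with shape `(num_layers, N, rnn_hidden_size)`. `RNNEncoder.infer` slices out one `(1, N, ...)` pair per layer, runs the layer, and concatenates the returned states back. A self-contained sketch of that bookkeeping, again with plain `nn.LSTM` layers standing in for the actual `RNNEncoderLayer` modules:

```python
# Hedged sketch of the per-layer state bookkeeping in RNNEncoder.infer, re-created
# here with plain nn.LSTM layers; the names below are illustrative, not icefall APIs.
from typing import List

import torch
import torch.nn as nn

num_layers, d_model, rnn_hidden_size = 3, 512, 1024
S, N = 10, 2

layers = nn.ModuleList(
    nn.LSTM(d_model, rnn_hidden_size, proj_size=d_model) for _ in range(num_layers)
)

# states[0]: (num_layers, N, d_model); states[1]: (num_layers, N, rnn_hidden_size)
states = (
    torch.zeros(num_layers, N, d_model),
    torch.zeros(num_layers, N, rnn_hidden_size),
)

output = torch.randn(S, N, d_model)
new_h: List[torch.Tensor] = []
new_c: List[torch.Tensor] = []
for i, lstm in enumerate(layers):
    # slice this layer's states: h is (1, N, d_model), c is (1, N, rnn_hidden_size)
    layer_states = (states[0][i : i + 1], states[1][i : i + 1])
    output, (h, c) = lstm(output, layer_states)
    new_h.append(h)
    new_c.append(c)

# stack the per-layer states back into the same layout as the input states
new_states = (torch.cat(new_h, dim=0), torch.cat(new_c, dim=0))
assert new_states[0].shape == states[0].shape
assert new_states[1].shape == states[1].shape
```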
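Note that `get_init_states` now returns a pair of per-stream tensors without a batch dimension, `(num_encoder_layers, d_model)` and `(num_encoder_layers, rnn_hidden_size)`, while `RNN.infer` expects batched `(num_layers, N, ...)` tensors. The following is an assumed usage pattern, not code from this diff: a caller would stack per-stream states along a new batch axis before calling `infer`. The helper name `get_init_states_like` is hypothetical.

```python
# Assumed usage (not part of this patch): batching per-stream initial states for
# RNN.infer(). get_init_states_like() mirrors RNN.get_init_states() from the diff.
import torch

num_encoder_layers, d_model, rnn_hidden_size = 12, 512, 1024

def get_init_states_like(device: torch.device = torch.device("cpu")):
    hidden_states = torch.zeros((num_encoder_layers, d_model), device=device)
    cell_states = torch.zeros((num_encoder_layers, rnn_hidden_size), device=device)
    return hidden_states, cell_states

# three independent streams -> one batched state tuple
per_stream = [get_init_states_like() for _ in range(3)]
batched = (
    torch.stack([s[0] for s in per_stream], dim=1),  # (num_layers, N, d_model)
    torch.stack([s[1] for s in per_stream], dim=1),  # (num_layers, N, rnn_hidden_size)
)
assert batched[0].shape == (num_encoder_layers, 3, d_model)
assert batched[1].shape == (num_encoder_layers, 3, rnn_hidden_size)
```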
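Finally, the `ScaledLSTM` changes in `scaling.py` stop hardcoding `num_layers=1` and `proj_size=0`, so the zero states it builds when `hx is None` follow the same convention as `torch.nn.LSTM`: the hidden state uses `proj_size` when projection is enabled, the cell state always uses `hidden_size`, and both carry a leading `num_layers` dimension. A small sketch of that convention with a stock `nn.LSTM` (sizes chosen only for illustration):

```python
# Sketch of the default-state convention that the updated ScaledLSTM.forward mirrors;
# uses a stock nn.LSTM, not the ScaledLSTM class itself.
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=512, hidden_size=1024, proj_size=512, num_layers=2)
x = torch.randn(10, 3, 512)  # (S, N, input_size)

real_hidden_size = lstm.proj_size if lstm.proj_size > 0 else lstm.hidden_size
h0 = torch.zeros(lstm.num_layers, x.size(1), real_hidden_size)
c0 = torch.zeros(lstm.num_layers, x.size(1), lstm.hidden_size)

y, (h, c) = lstm(x, (h0, c0))
assert h.shape == h0.shape and c.shape == c0.shape
```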