From 2d53f2ef8b87f278bdd8cfeaa670fcbcccd681b9 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Sun, 17 Jul 2022 12:59:27 +0800
Subject: [PATCH] add RNN and Conv2dSubsampling classes in lstm.py

---
 .../ASR/lstm_transducer_stateless/lstm.py | 263 +++++++++++++++++-
 1 file changed, 250 insertions(+), 13 deletions(-)

diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py
index a0f463797..e180d9ec6 100644
--- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py
@@ -15,9 +15,7 @@
 # limitations under the License.
 
 import copy
-import math
-import warnings
-from typing import List, Optional, Tuple
+from typing import Tuple
 
 import torch
 from encoder_interface import EncoderInterface
@@ -25,12 +23,157 @@ from scaling import (
     ActivationBalancer,
     BasicNorm,
     DoubleSwish,
-    ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,
     ScaledLSTM,
 )
-from torch import Tensor, nn
+from torch import nn
+
+
+class RNN(EncoderInterface):
+    """
+    Args:
+      num_features (int):
+        Number of input features.
+      subsampling_factor (int):
+        Subsampling factor of encoder (convolution layers before lstm layers).
+      d_model (int):
+        Hidden dimension for lstm layers, also output dimension (default=512).
+      dim_feedforward (int):
+        Feedforward dimension (default=2048).
+      num_encoder_layers (int):
+        Number of encoder layers (default=12).
+      dropout (float):
+        Dropout rate (default=0.1).
+      layer_dropout (float):
+        Dropout value for model-level warmup (default=0.075).
+    """
+
+    def __init__(
+        self,
+        num_features: int,
+        subsampling_factor: int,
+        d_model: int = 512,
+        dim_feedforward: int = 2048,
+        num_encoder_layers: int = 12,
+        dropout: float = 0.1,
+        layer_dropout: float = 0.075,
+    ) -> None:
+        super(RNN, self).__init__()
+
+        self.num_features = num_features
+        self.subsampling_factor = subsampling_factor
+        if subsampling_factor != 4:
+            raise NotImplementedError("Support only 'subsampling_factor=4'.")
+
+        # self.encoder_embed converts the input of shape (N, T, num_features)
+        # to the shape (N, T//subsampling_factor, d_model).
+        # That is, it does two things simultaneously:
+        #   (1) subsampling: T -> T//subsampling_factor
+        #   (2) embedding: num_features -> d_model
+        self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+
+        self.num_encoder_layers = num_encoder_layers
+        self.d_model = d_model
+
+        encoder_layer = RNNEncoderLayer(
+            d_model, dim_feedforward, dropout, layer_dropout
+        )
+        self.encoder = RNNEncoder(encoder_layer, num_encoder_layers)
+
+    def forward(
+        self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            The input tensor. Its shape is (N, T, C), where N is the batch size,
+            T is the sequence length, C is the feature dimension.
+          x_lens:
+            A tensor of shape (N,), containing the number of frames in `x`
+            before padding.
+          warmup:
+            A floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.
+
+        Returns:
+          A tuple of 2 tensors:
+            - embeddings: its shape is (N, T', d_model), where T' is the output
+              sequence length.
+            - lengths: a tensor of shape (batch_size,) containing the number of
+              frames in `embeddings` before padding.
+        """
+        x = self.encoder_embed(x)
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+        # lengths = ((x_lens - 1) // 2 - 1) // 2  # issues a warning
+        #
+        # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
+        lengths = (((x_lens - 1) >> 1) - 1) >> 1
+        assert x.size(0) == lengths.max().item()
+
+        x = self.encoder(x, warmup)
+
+        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        return x, lengths
+
+    @torch.jit.export
+    def get_init_state(self, device: torch.device) -> torch.Tensor:
+        """Get model initial state."""
+        init_states = torch.zeros(
+            (2, self.num_encoder_layers, self.d_model), device=device
+        )
+        return init_states
+
+    @torch.jit.export
+    def infer(
+        self, x: torch.Tensor, x_lens: torch.Tensor, states: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            The input tensor. Its shape is (N, T, C), where N is the batch size,
+            T is the sequence length, C is the feature dimension.
+          x_lens:
+            A tensor of shape (N,), containing the number of frames in `x`
+            before padding.
+          states:
+            Its shape is (2, num_encoder_layers, N, E).
+            states[0] and states[1] are cached hidden states and cell states for
+            all layers, respectively.
+
+        Returns:
+          A tuple of 3 tensors:
+            - embeddings: its shape is (N, T', d_model), where T' is the output
+              sequence length.
+            - lengths: a tensor of shape (batch_size,) containing the number of
+              frames in `embeddings` before padding.
+            - updated states, of shape (2, num_encoder_layers, N, E).
+        """
+        assert not self.training
+        assert states.shape == (
+            2,
+            self.num_encoder_layers,
+            x.size(0),
+            self.d_model,
+        ), states.shape
+
+        # lengths = ((x_lens - 1) // 2 - 1) // 2  # issues a warning
+        #
+        # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
+        lengths = (((x_lens - 1) >> 1) - 1) >> 1
+        # we will cut off 1 frame on each side of encoder_embed output
+        lengths -= 2
+
+        embed = self.encoder_embed(x)
+        embed = embed[:, 1:-1, :]
+        embed = embed.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+        x, states = self.encoder.infer(embed, states)
+
+        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        return x, lengths, states
 
 
 class RNNEncoderLayer(nn.Module):
@@ -77,7 +220,7 @@ class RNNEncoderLayer(nn.Module):
         )
         self.dropout = nn.Dropout(dropout)
 
-    def forward(self, src: Tensor, warmup: float = 1.0) -> Tensor:
+    def forward(self, src: torch.Tensor, warmup: float = 1.0) -> torch.Tensor:
         """
         Pass the input through the encoder layer.
 
@@ -120,8 +263,8 @@ class RNNEncoderLayer(nn.Module):
 
     @torch.jit.export
     def infer(
-        self, src: Tensor, states: Tuple[Tensor]
-    ) -> Tuple[Tensor, Tuple[Tensor]]:
+        self, src: torch.Tensor, states: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Pass the input through the encoder layer.
 
@@ -139,9 +282,9 @@ class RNNEncoderLayer(nn.Module):
         assert states.shape == (2, 1, src.size(1), src.size(2))
 
         # lstm module
-        # The required shapes of h_0 and c_0 are both (1, N, E)
+        # The required shapes of h_0 and c_0 are both (1, N, E).
         src_lstm, new_states = self.lstm(src, states.unbind(dim=0))
-        new_states = torch.stack(states, dim=0)
+        new_states = torch.stack(new_states, dim=0)
         src = src + self.dropout(src_lstm)
 
         # feed forward module
@@ -170,7 +313,7 @@ class RNNEncoder(nn.Module):
         )
         self.num_layers = num_layers
 
-    def forward(self, src: Tensor, warmup: float = 1.0) -> Tensor:
+    def forward(self, src: torch.Tensor, warmup: float = 1.0) -> torch.Tensor:
         """
         Pass the input through the encoder layer in turn.
@@ -192,8 +335,8 @@ class RNNEncoder(nn.Module):
 
     @torch.jit.export
     def infer(
-        self, src: Tensor, states: Tuple[Tensor]
-    ) -> Tuple[Tensor, Tuple[Tensor]]:
+        self, src: torch.Tensor, states: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Pass the input through the encoder layer.
 
@@ -220,3 +363,97 @@ class RNNEncoder(nn.Module):
             new_states_list.append(new_states)
 
         return output, torch.cat(new_states_list, dim=1)
+
+
+class Conv2dSubsampling(nn.Module):
+    """Convolutional 2D subsampling (to 1/4 length).
+
+    Convert an input of shape (N, T, idim) to an output
+    with shape (N, T', odim), where
+    T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
+
+    It is based on
+    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        layer1_channels: int = 8,
+        layer2_channels: int = 32,
+        layer3_channels: int = 128,
+    ) -> None:
+        """
+        Args:
+          in_channels:
+            Number of channels in. The input shape is (N, T, in_channels).
+            Caution: It requires: T >= 7, in_channels >= 7
+          out_channels:
+            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
+          layer1_channels:
+            Number of channels in layer1
+          layer2_channels:
+            Number of channels in layer2
+        """
+        assert in_channels >= 7
+        super().__init__()
+
+        self.conv = nn.Sequential(
+            ScaledConv2d(
+                in_channels=1,
+                out_channels=layer1_channels,
+                kernel_size=3,
+                padding=1,
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+            ScaledConv2d(
+                in_channels=layer1_channels,
+                out_channels=layer2_channels,
+                kernel_size=3,
+                stride=2,
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+            ScaledConv2d(
+                in_channels=layer2_channels,
+                out_channels=layer3_channels,
+                kernel_size=3,
+                stride=2,
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+        )
+        self.out = ScaledLinear(
+            layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels
+        )
+        # set learn_eps=False because out_norm is preceded by `out`, and `out`
+        # itself has learned scale, so the extra degree of freedom is not
+        # needed.
+        self.out_norm = BasicNorm(out_channels, learn_eps=False)
+        # constrain median of output to be close to zero.
+        self.out_balancer = ActivationBalancer(
+            channel_dim=-1, min_positive=0.45, max_positive=0.55
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Subsample x.
+
+        Args:
+          x:
+            Its shape is (N, T, idim).
+
+        Returns:
+          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
+        """
+        # On entry, x is (N, T, idim)
+        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
+        x = self.conv(x)
+        # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
+        x = self.out_norm(x)
+        x = self.out_balancer(x)
+        return x
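
Usage note (not part of the patch): the sketch below shows how the new RNN encoder could be exercised offline, mainly to make the length bookkeeping concrete. It assumes the patched lstm.py and its recipe dependencies (encoder_interface.py and a scaling.py providing ScaledLSTM) are importable, and uses 80-dim features and the default layer sizes purely for illustration; it is not taken from the recipe's training or decoding scripts.

import torch

from lstm import RNN

encoder = RNN(
    num_features=80,
    subsampling_factor=4,
    d_model=512,
    num_encoder_layers=12,
)
encoder.eval()

N, T, C = 2, 103, 80
x = torch.randn(N, T, C)            # (N, T, num_features)
x_lens = torch.tensor([103, 98])    # valid frames per utterance

with torch.no_grad():
    out, out_lens = encoder(x, x_lens)

# Conv2dSubsampling maps T input frames to ((T - 1) // 2 - 1) // 2 output
# frames, so here:
#   out.shape == (2, 25, 512)      since ((103 - 1) // 2 - 1) // 2 == 25
#   out_lens  == tensor([25, 23])  since ((98 - 1) // 2 - 1) // 2 == 23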
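
Continuing the sketch above, the streaming path goes through get_init_state() and infer() instead of forward(). Expanding the single-stream initial state returned by get_init_state() over the batch dimension is an assumption about the intended usage, not something this patch prescribes; the 23-frame chunk size is likewise only an example.

device = torch.device("cpu")

# (2, num_encoder_layers, N, d_model): states[0] holds the cached hidden
# states and states[1] the cached cell states of all layers.
states = (
    encoder.get_init_state(device)
    .unsqueeze(2)
    .expand(-1, -1, N, -1)
    .contiguous()
)

chunk = torch.randn(N, 23, C)
chunk_lens = torch.full((N,), 23, dtype=torch.int64)

with torch.no_grad():
    out, out_lens, states = encoder.infer(chunk, chunk_lens, states)

# infer() additionally drops one subsampled frame on each side, so a 23-frame
# chunk yields ((23 - 1) // 2 - 1) // 2 - 2 == 3 output frames per stream,
# and `states` now carries the LSTM states to feed into the next chunk.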