diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index 9c8302926..6b625513e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 
 import copy
+from encoder_interface import EncoderInterface
 import math
 import warnings
 from typing import Optional, Tuple, Sequence
@@ -23,12 +24,11 @@ from scaling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, Sc
 import torch
 from torch import Tensor, nn
 
-from transformer import Transformer
 
 from icefall.utils import make_pad_mask
 
 
-class Conformer(Transformer):
+class Conformer(EncoderInterface):
     """
     Args:
         num_features (int): Number of input features
@@ -40,7 +40,6 @@ class Conformer(Transformer):
         num_encoder_layers (int): number of encoder layers
         dropout (float): dropout rate
         cnn_module_kernel (int): Kernel size of convolution module
-        normalize_before (bool): whether to use layer_norm before the first block.
         vgg_frontend (bool): whether to use vgg frontend.
     """
 
@@ -55,22 +54,22 @@ class Conformer(Transformer):
         num_encoder_layers: int = 12,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
-        normalize_before: bool = True,
-        vgg_frontend: bool = False,
         aux_layer_period: int = 3
     ) -> None:
-        super(Conformer, self).__init__(
-            num_features=num_features,
-            output_dim=output_dim,
-            subsampling_factor=subsampling_factor,
-            d_model=d_model,
-            nhead=nhead,
-            dim_feedforward=dim_feedforward,
-            num_encoder_layers=num_encoder_layers,
-            dropout=dropout,
-            normalize_before=normalize_before,
-            vgg_frontend=vgg_frontend,
-        )
+        super(Conformer, self).__init__()
+
+        self.num_features = num_features
+        self.output_dim = output_dim
+        self.subsampling_factor = subsampling_factor
+        if subsampling_factor != 4:
+            raise NotImplementedError("Support only 'subsampling_factor=4'.")
+
+        # self.encoder_embed converts the input of shape (N, T, num_features)
+        # to the shape (N, T//subsampling_factor, d_model).
+        # That is, it does two things simultaneously:
+        #   (1) subsampling: T -> T//subsampling_factor
+        #   (2) embedding: num_features -> d_model
+        self.encoder_embed = Conv2dSubsampling(num_features, d_model)
 
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
 
@@ -80,11 +79,13 @@ class Conformer(Transformer):
             dim_feedforward,
             dropout,
             cnn_module_kernel,
-            normalize_before,
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))
-        self.normalize_before = normalize_before
+
+        self.encoder_output_layer = nn.Sequential(
+            nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim)
+        )
 
     def forward(
@@ -136,7 +137,6 @@ class ConformerEncoderLayer(nn.Module):
         dim_feedforward: the dimension of the feedforward network model (default=2048).
         dropout: the dropout value (default=0.1).
         cnn_module_kernel (int): Kernel size of convolution module.
-        normalize_before: whether to use layer_norm before the first block.
 
     Examples::
         >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@@ -152,7 +152,6 @@ class ConformerEncoderLayer(nn.Module):
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
-        normalize_before: bool = True,
    ) -> None:
         super(ConformerEncoderLayer, self).__init__()
         self.d_model = d_model
@@ -942,6 +941,80 @@ class Identity(torch.nn.Module):
         return x
 
 
+class Conv2dSubsampling(nn.Module):
+    """Convolutional 2D subsampling (to 1/4 length).
+
+    Convert an input of shape (N, T, idim) to an output
+    with shape (N, T', odim), where
+    T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
+
+    It is based on
+    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
+    """
+
+    def __init__(self, in_channels: int,
+                 out_channels: int,
+                 layer1_channels: int = 64,
+                 layer2_channels: int = 128) -> None:
+        """
+        Args:
+          in_channels:
+            Number of input channels. The input shape is (N, T, in_channels).
+            Caution: It requires: T >=7, in_channels >=7
+          out_channels:
+            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
+          layer1_channels:
+            Number of channels in layer1
+          layer2_channels:
+            Number of channels in layer2
+        """
+        assert in_channels >= 7
+        super().__init__()
+        self.conv = nn.Sequential(
+            ScaledConv2d(
+                in_channels=1, out_channels=layer1_channels,
+                kernel_size=3, stride=2
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+            ScaledConv2d(
+                in_channels=layer1_channels, out_channels=layer2_channels,
+                kernel_size=3, stride=2
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+        )
+        self.out = ScaledLinear(layer2_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels)
+        # set learn_eps=False because out_norm is preceded by `out`, and `out`
+        # itself has learned scale, so the extra degree of freedom is not
+        # needed.
+        self.out_norm = BasicNorm(out_channels, learn_eps=False)
+        # constrain median of output to be close to zero.
+        self.out_balancer = ActivationBalancer(channel_dim=-1,
+                                               min_positive=0.45,
+                                               max_positive=0.55)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Subsample x.
+
+        Args:
+          x:
+            Its shape is (N, T, idim).
+
+        Returns:
+          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
+        """
+        # On entry, x is (N, T, idim)
+        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
+        x = self.conv(x)
+        # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
+        x = self.out_norm(x)
+        x = self.out_balancer(x)
+        return x
+
+
 if __name__ == '__main__':
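The Conformer no longer inherits its frontend and output head from Transformer: it now owns encoder_embed, encoder_pos, encoder, and encoder_output_layer directly. A minimal sketch of the resulting forward data flow (shapes only; layer internals elided, and names taken from the diff above):

    # (N, T, num_features)
    #   -> encoder_embed (Conv2dSubsampling): (N, T', d_model), T' = ((T-1)//2 - 1)//2
    #   -> encoder_pos (RelPositionalEncoding): adds relative positional info
    #   -> encoder (ConformerEncoder, num_encoder_layers blocks)
    #   -> encoder_output_layer (Dropout + ScaledLinear): (N, T', output_dim)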
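Aside: the two kernel-3, stride-2, padding-0 convolutions above are exactly what produces the output length T' = ((T-1)//2 - 1)//2 quoted in the docstrings. A minimal self-contained check, using plain nn.Conv2d stand-ins for icefall's ScaledConv2d (an assumption made only to keep the snippet dependency-free):

    import torch
    import torch.nn as nn

    # Two kernel-3, stride-2, padding-0 convolutions, as in Conv2dSubsampling.
    conv = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=3, stride=2),    # T -> (T - 1) // 2
        nn.Conv2d(64, 128, kernel_size=3, stride=2),  # ... -> ((T - 1) // 2 - 1) // 2
    )

    for T in (7, 100, 163):
        x = torch.randn(1, 1, T, 80)  # (N, C=1, T, num_features)
        t_out = conv(x).shape[2]
        assert t_out == ((T - 1) // 2 - 1) // 2
        print(T, "->", t_out)  # 7 -> 1, 100 -> 24, 163 -> 40

This is also why the module requires T >= 7 and in_channels >= 7: anything shorter leaves zero frames (or zero frequency bins) after the second convolution.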
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py
index c2da23adc..12ca09a17 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py
@@ -32,34 +32,43 @@ class Conv2dSubsampling(nn.Module):
     https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
     """
 
-    def __init__(self, idim: int, odim: int) -> None:
+    def __init__(self, in_channels: int,
+                 out_channels: int,
+                 layer1_channels: int = 64,
+                 layer2_channels: int = 128) -> None:
         """
         Args:
-          idim:
-            Input dim. The input shape is (N, T, idim).
-            Caution: It requires: T >=7, idim >=7
-          odim:
-            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
+          in_channels:
+            Number of input channels. The input shape is (N, T, in_channels).
+            Caution: It requires: T >=7, in_channels >=7
+          out_channels:
+            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
+          layer1_channels:
+            Number of channels in layer1
+          layer2_channels:
+            Number of channels in layer2
         """
-        assert idim >= 7
+        assert in_channels >= 7
         super().__init__()
         self.conv = nn.Sequential(
             ScaledConv2d(
-                in_channels=1, out_channels=odim, kernel_size=3, stride=2
+                in_channels=1, out_channels=layer1_channels,
+                kernel_size=3, stride=2
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
             ScaledConv2d(
-                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
+                in_channels=layer1_channels, out_channels=layer2_channels,
+                kernel_size=3, stride=2
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
         )
-        self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
+        self.out = ScaledLinear(layer2_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels)
         # set learn_eps=False because out_norm is preceded by `out`, and `out`
         # itself has learned scale, so the extra degree of freedom is not
         # needed.
-        self.out_norm = BasicNorm(odim, learn_eps=False)
+        self.out_norm = BasicNorm(out_channels, learn_eps=False)
         # constrain median of output to be close to zero.
         self.out_balancer = ActivationBalancer(channel_dim=-1,
                                                min_positive=0.45,
                                                max_positive=0.55)
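Note on the widened signature: out_channels is now decoupled from the internal conv channel counts, and the final ScaledLinear consumes layer2_channels times the number of subsampled frequency bins. A rough pure-PyTorch analogue of the shape bookkeeping (nn.Conv2d/nn.Linear standing in for ScaledConv2d/ScaledLinear; the balancers and BasicNorm are omitted, so this is a sketch, not the module above):

    import torch
    import torch.nn as nn

    class ToyConv2dSubsampling(nn.Module):
        def __init__(self, in_channels, out_channels,
                     layer1_channels=64, layer2_channels=128):
            super().__init__()
            self.conv = nn.Sequential(
                nn.Conv2d(1, layer1_channels, kernel_size=3, stride=2),
                nn.ReLU(),
                nn.Conv2d(layer1_channels, layer2_channels, kernel_size=3, stride=2),
                nn.ReLU(),
            )
            # Frequency bins left after the two stride-2 convolutions.
            freq = ((in_channels - 1) // 2 - 1) // 2
            self.out = nn.Linear(layer2_channels * freq, out_channels)

        def forward(self, x):              # x: (N, T, in_channels)
            x = self.conv(x.unsqueeze(1))  # (N, layer2_channels, T', F')
            b, c, t, f = x.size()
            return self.out(x.transpose(1, 2).reshape(b, t, c * f))

    m = ToyConv2dSubsampling(80, 512)
    print(m(torch.randn(2, 100, 80)).shape)  # torch.Size([2, 24, 512])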
@@ -86,99 +95,3 @@ class Conv2dSubsampling(nn.Module):
         x = self.out_norm(x)
         x = self.out_balancer(x)
         return x
-
-
-class VggSubsampling(nn.Module):
-    """Trying to follow the setup described in the following paper:
-    https://arxiv.org/pdf/1910.09799.pdf
-
-    This paper is not 100% explicit so I am guessing to some extent,
-    and trying to compare with other VGG implementations.
-
-    Convert an input of shape (N, T, idim) to an output
-    with shape (N, T', odim), where
-    T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
-    """
-
-    def __init__(self, idim: int, odim: int) -> None:
-        """Construct a VggSubsampling object.
-
-        This uses 2 VGG blocks with 2 Conv2d layers each,
-        subsampling its input by a factor of 4 in the time dimensions.
-
-        Args:
-          idim:
-            Input dim. The input shape is (N, T, idim).
-            Caution: It requires: T >=7, idim >=7
-          odim:
-            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
-        """
-        super().__init__()
-
-        cur_channels = 1
-        layers = []
-        block_dims = [32, 64]
-
-        # The decision to use padding=1 for the 1st convolution, then padding=0
-        # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
-        # a back-compatibility concern so that the number of frames at the
-        # output would be equal to:
-        #   (((T-1)//2)-1)//2.
-        # We can consider changing this by using padding=1 on the
-        # 2nd convolution, so the num-frames at the output would be T//4.
-        for block_dim in block_dims:
-            layers.append(
-                torch.nn.Conv2d(
-                    in_channels=cur_channels,
-                    out_channels=block_dim,
-                    kernel_size=3,
-                    padding=1,
-                    stride=1,
-                )
-            )
-            layers.append(torch.nn.ReLU())
-            layers.append(
-                torch.nn.Conv2d(
-                    in_channels=block_dim,
-                    out_channels=block_dim,
-                    kernel_size=3,
-                    padding=0,
-                    stride=1,
-                )
-            )
-            layers.append(
-                torch.nn.MaxPool2d(
-                    kernel_size=2, stride=2, padding=0, ceil_mode=True
-                )
-            )
-            cur_channels = block_dim
-
-        self.layers = nn.Sequential(*layers)
-
-        self.out = nn.Linear(
-            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
-        )
-        self.out_norm = BasicNorm(odim, learn_eps=False)
-        # constrain median of output to be close to zero.
-        self.out_balancer = ActivationBalancer(channel_dim=-1,
-                                               min_positive=0.45,
-                                               max_positive=0.55)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Subsample x.
-
-        Args:
-          x:
-            Its shape is (N, T, idim).
-
-        Returns:
-          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
-        """
-        x = x.unsqueeze(1)
-        x = self.layers(x)
-        b, c, t, f = x.size()
-        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
-        x = self.out_norm(x)
-        x = self.out_balancer(x)
-        return x
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index c1e836903..237eb8bbd 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -291,7 +291,6 @@ def get_params() -> AttributeDict:
             "nhead": 8,
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
-            "vgg_frontend": False,
             # parameters for decoder
             "embedding_dim": 512,
             # parameters for Noam
@@ -314,7 +313,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         nhead=params.nhead,
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
     )
     return encoder
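With vgg_frontend dropped from get_params() and get_encoder_model(), building the encoder reduces to the call sketched below. A hedged sketch: the literal values are illustrative (they mirror defaults visible in this patch, not necessarily the recipe's settings), and the import assumes the caller sits next to conformer.py:

    from conformer import Conformer

    encoder = Conformer(
        num_features=80,        # e.g. 80-dim fbank input (illustrative)
        output_dim=512,         # encoder output dimension (illustrative)
        subsampling_factor=4,   # anything else raises NotImplementedError
        d_model=512,
        nhead=8,
        dim_feedforward=2048,
        num_encoder_layers=12,
        # note: no vgg_frontend and no normalize_before arguments any more
    )

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py
index aa091877c..a58702e1d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py
@@ -78,10 +78,7 @@ class Transformer(EncoderInterface):
         # That is, it does two things simultaneously:
         #   (1) subsampling: T -> T//subsampling_factor
         #   (2) embedding: num_features -> d_model
-        if vgg_frontend:
-            self.encoder_embed = VggSubsampling(num_features, d_model)
-        else:
-            self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+        self.encoder_embed = Conv2dSubsampling(num_features, d_model)
 
         self.encoder_pos = PositionalEncoding(d_model, dropout)

Since both Transformer and Conformer now always subsample with Conv2dSubsampling, any downstream length bookkeeping (e.g. the lengths fed to make_pad_mask) must shrink by the same rule. A small sketch; the helper name is ours, but the arithmetic matches the convolutions above:

    import torch

    def subsampled_lengths(x_lens: torch.Tensor) -> torch.Tensor:
        # Frame counts after two kernel-3, stride-2, padding-0 convolutions.
        return ((x_lens - 1) // 2 - 1) // 2

    x_lens = torch.tensor([163, 100, 57])
    print(subsampled_lengths(x_lens))  # tensor([40, 24, 13])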