Refactoring and simplifying conformer and frontend

Daniel Povey 2022-03-29 20:32:14 +08:00
parent 57f943b25c
commit 11124b03ea
4 changed files with 115 additions and 134 deletions


@@ -16,6 +16,7 @@
# limitations under the License.
import copy
from encoder_interface import EncoderInterface
import math
import warnings
from typing import Optional, Tuple, Sequence
@@ -23,12 +24,11 @@ from scaling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, Sc
import torch
from torch import Tensor, nn
from transformer import Transformer
from icefall.utils import make_pad_mask
class Conformer(Transformer):
class Conformer(EncoderInterface):
"""
Args:
num_features (int): Number of input features
@@ -40,7 +40,6 @@ class Conformer(Transformer):
num_encoder_layers (int): number of encoder layers
dropout (float): dropout rate
cnn_module_kernel (int): Kernel size of convolution module
normalize_before (bool): whether to use layer_norm before the first block.
vgg_frontend (bool): whether to use vgg frontend.
"""
@@ -55,22 +54,22 @@
num_encoder_layers: int = 12,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
aux_layer_period: int = 3
) -> None:
super(Conformer, self).__init__(
num_features=num_features,
output_dim=output_dim,
subsampling_factor=subsampling_factor,
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
num_encoder_layers=num_encoder_layers,
dropout=dropout,
normalize_before=normalize_before,
vgg_frontend=vgg_frontend,
)
super(Conformer, self).__init__()
self.num_features = num_features
self.output_dim = output_dim
self.subsampling_factor = subsampling_factor
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
# self.encoder_embed converts the input of shape (N, T, num_features)
# to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
self.encoder_embed = Conv2dSubsampling(num_features, 128, d_model)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@@ -80,11 +79,13 @@ class Conformer(Transformer):
dim_feedforward,
dropout,
cnn_module_kernel,
normalize_before,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))
self.normalize_before = normalize_before
self.encoder_output_layer = nn.Sequential(
nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim)
)
def forward(
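With the defaults shown above (num_encoder_layers=12, aux_layer_period=3), the aux_layers expression evaluates to layers 0, 3, 6 and 9. A plain-Python check of that arithmetic:

num_encoder_layers = 12
aux_layer_period = 3
aux_layers = list(range(0, num_encoder_layers - 1, aux_layer_period))
print(aux_layers)  # [0, 3, 6, 9]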
@@ -136,7 +137,6 @@ class ConformerEncoderLayer(nn.Module):
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
cnn_module_kernel (int): Kernel size of convolution module.
normalize_before: whether to use layer_norm before the first block.
Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@@ -152,7 +152,6 @@ class ConformerEncoderLayer(nn.Module):
dim_feedforward: int = 2048,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.d_model = d_model
@@ -942,6 +941,80 @@ class Identity(torch.nn.Module):
return x
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(self, in_channels: int,
out_channels: int,
layer1_channels: int = 64,
layer2_channels: int = 128) -> None:
"""
Args:
in_channels:
Number of input channels. The input shape is (N, T, in_channels).
Caution: It requires: T >= 7, in_channels >= 7
out_channels:
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
layer1_channels:
Number of channels in layer1
layer2_channels:
Number of channels in layer2
"""
assert in_channels >= 7
super().__init__()
self.conv = nn.Sequential(
ScaledConv2d(
in_channels=1, out_channels=layer1_channels,
kernel_size=3, stride=2
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
ScaledConv2d(
in_channels=layer1_channels, out_channels=layer2_channels,
kernel_size=3, stride=2
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
)
self.out = ScaledLinear(layer2_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels)
# set learn_eps=False because out_norm is preceded by `out`, and `out`
# itself has learned scale, so the extra degree of freedom is not
# needed.
self.out_norm = BasicNorm(out_channels, learn_eps=False)
# constrain median of output to be close to zero.
self.out_balancer = ActivationBalancer(channel_dim=-1,
min_positive=0.45,
max_positive=0.55)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
Returns:
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x)
# Now x is of shape (N, layer2_channels, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
x = self.out_norm(x)
x = self.out_balancer(x)
return x
if __name__ == '__main__':
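A shape-only sanity check of the subsampling arithmetic above. Plain nn.Conv2d stands in for ScaledConv2d here, on the assumption that ScaledConv2d (from this repo's scaling.py) only reparameterizes the weights and leaves output shapes unchanged:

import torch
import torch.nn as nn

# Each kernel_size=3, stride=2, no-padding conv maps a length L to
# (L - 1) // 2, so the time axis goes T -> ((T - 1) // 2 - 1) // 2.
T, idim = 100, 80
x = torch.randn(1, 1, T, idim)  # (N, 1, T, idim), as after x.unsqueeze(1)
conv = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=3, stride=2),    # layer1_channels = 64
    nn.Conv2d(64, 128, kernel_size=3, stride=2),  # layer2_channels = 128
)
y = conv(x)
print(y.shape)  # torch.Size([1, 128, 24, 19])
assert y.shape[2] == ((T - 1) // 2 - 1) // 2  # 24 frames for T = 100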


@@ -32,34 +32,43 @@ class Conv2dSubsampling(nn.Module):
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(self, idim: int, odim: int) -> None:
def __init__(self, in_channels: int,
out_channels: int,
layer1_channels: int = 64,
layer2_channels: int = 128) -> None:
"""
Args:
idim:
Input dim. The input shape is (N, T, idim).
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
in_channels:
Number of input channels. The input shape is (N, T, in_channels).
Caution: It requires: T >= 7, in_channels >= 7
out_channels:
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
layer1_channels:
Number of channels in layer1
layer2_channels:
Number of channels in layer2
"""
assert idim >= 7
assert in_channels >= 7
super().__init__()
self.conv = nn.Sequential(
ScaledConv2d(
in_channels=1, out_channels=odim, kernel_size=3, stride=2
in_channels=1, out_channels=layer1_channels,
kernel_size=3, stride=2
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
ScaledConv2d(
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
in_channels=layer1_channels, out_channels=layer2_channels,
kernel_size=3, stride=2
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
)
self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
self.out = ScaledLinear(layer2_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels)
# set learn_eps=False because out_norm is preceded by `out`, and `out`
# itself has learned scale, so the extra degree of freedom is not
# needed.
self.out_norm = BasicNorm(odim, learn_eps=False)
self.out_norm = BasicNorm(out_channels, learn_eps=False)
# constrain median of output to be close to zero.
self.out_balancer = ActivationBalancer(channel_dim=-1,
min_positive=0.45,
@@ -86,99 +95,3 @@ class Conv2dSubsampling(nn.Module):
x = self.out_norm(x)
x = self.out_balancer(x)
return x
class VggSubsampling(nn.Module):
"""Trying to follow the setup described in the following paper:
https://arxiv.org/pdf/1910.09799.pdf
This paper is not 100% explicit so I am guessing to some extent,
and trying to compare with other VGG implementations.
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
"""
def __init__(self, idim: int, odim: int) -> None:
"""Construct a VggSubsampling object.
This uses 2 VGG blocks with 2 Conv2d layers each,
subsampling its input by a factor of 4 in the time dimension.
Args:
idim:
Input dim. The input shape is (N, T, idim).
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
"""
super().__init__()
cur_channels = 1
layers = []
block_dims = [32, 64]
# The decision to use padding=1 for the 1st convolution, then padding=0
# for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
# a back-compatibility concern so that the number of frames at the
# output would be equal to:
# (((T-1)//2)-1)//2.
# We can consider changing this by using padding=1 on the
# 2nd convolution, so the num-frames at the output would be T//4.
for block_dim in block_dims:
layers.append(
torch.nn.Conv2d(
in_channels=cur_channels,
out_channels=block_dim,
kernel_size=3,
padding=1,
stride=1,
)
)
layers.append(torch.nn.ReLU())
layers.append(
torch.nn.Conv2d(
in_channels=block_dim,
out_channels=block_dim,
kernel_size=3,
padding=0,
stride=1,
)
)
layers.append(
torch.nn.MaxPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True
)
)
cur_channels = block_dim
self.layers = nn.Sequential(*layers)
self.out = nn.Linear(
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
)
self.out_norm = BasicNorm(odim, learn_eps=False)
# constrain median of output to be close to zero.
self.out_balancer = ActivationBalancer(channel_dim=-1,
min_positive=0.45,
max_positive=0.55)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
Returns:
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
"""
x = x.unsqueeze(1)
x = self.layers(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
x = self.out_norm(x)
x = self.out_balancer(x)
return x
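The padding and ceil_mode reasoning in the comment above can be checked numerically: each block maps a length L to ceil((L - 2) / 2), which equals (L - 1) // 2, so two blocks give the same ((T-1)//2 - 1)//2 frame count as Conv2dSubsampling. A shape-only sketch:

import torch
import torch.nn as nn

def vgg_block(cin: int, cout: int) -> nn.Sequential:
    # padding=1 conv keeps L; padding=0 conv gives L - 2; the ceil_mode
    # max-pool then gives ceil((L - 2) / 2) == (L - 1) // 2.
    return nn.Sequential(
        nn.Conv2d(cin, cout, kernel_size=3, padding=1, stride=1),
        nn.ReLU(),
        nn.Conv2d(cout, cout, kernel_size=3, padding=0, stride=1),
        nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True),
    )

T, idim = 100, 80
x = torch.randn(1, 1, T, idim)
y = nn.Sequential(vgg_block(1, 32), vgg_block(32, 64))(x)
assert y.shape[2] == ((T - 1) // 2 - 1) // 2  # 24 frames, matching Conv2dSubsampling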


@@ -291,7 +291,6 @@ def get_params() -> AttributeDict:
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
# parameters for decoder
"embedding_dim": 512,
# parameters for Noam
@@ -314,7 +313,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
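For reference, the encoder-related configuration after this change; only the keys visible in the hunks above are listed, and the rest of get_params() is elided:

# Illustrative fragment: just the keys shown above, with "vgg_frontend"
# removed by this commit; get_params() contains many more fields.
encoder_params = {
    "nhead": 8,
    "dim_feedforward": 2048,
    "num_encoder_layers": 12,
    "embedding_dim": 512,
}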


@@ -78,9 +78,6 @@ class Transformer(EncoderInterface):
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
if vgg_frontend:
self.encoder_embed = VggSubsampling(num_features, d_model)
else:
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_pos = PositionalEncoding(d_model, dropout)