mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
Refactoring and simplifying conformer and frontend
This commit is contained in:
parent
57f943b25c
commit
11124b03ea
@ -16,6 +16,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from encoder_interface import EncoderInterface
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Sequence
|
||||
@ -23,12 +24,11 @@ from scaling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, Sc
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from transformer import Transformer
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
|
||||
class Conformer(Transformer):
|
||||
class Conformer(EncoderInterface):
|
||||
"""
|
||||
Args:
|
||||
num_features (int): Number of input features
|
||||
@ -40,7 +40,6 @@ class Conformer(Transformer):
|
||||
num_encoder_layers (int): number of encoder layers
|
||||
dropout (float): dropout rate
|
||||
cnn_module_kernel (int): Kernel size of convolution module
|
||||
normalize_before (bool): whether to use layer_norm before the first block.
|
||||
vgg_frontend (bool): whether to use vgg frontend.
|
||||
"""
|
||||
|
||||
@ -55,22 +54,22 @@ class Conformer(Transformer):
|
||||
num_encoder_layers: int = 12,
|
||||
dropout: float = 0.1,
|
||||
cnn_module_kernel: int = 31,
|
||||
normalize_before: bool = True,
|
||||
vgg_frontend: bool = False,
|
||||
aux_layer_period: int = 3
|
||||
) -> None:
|
||||
super(Conformer, self).__init__(
|
||||
num_features=num_features,
|
||||
output_dim=output_dim,
|
||||
subsampling_factor=subsampling_factor,
|
||||
d_model=d_model,
|
||||
nhead=nhead,
|
||||
dim_feedforward=dim_feedforward,
|
||||
num_encoder_layers=num_encoder_layers,
|
||||
dropout=dropout,
|
||||
normalize_before=normalize_before,
|
||||
vgg_frontend=vgg_frontend,
|
||||
)
|
||||
super(Conformer, self).__init__()
|
||||
|
||||
self.num_features = num_features
|
||||
self.output_dim = output_dim
|
||||
self.subsampling_factor = subsampling_factor
|
||||
if subsampling_factor != 4:
|
||||
raise NotImplementedError("Support only 'subsampling_factor=4'.")
|
||||
|
||||
# self.encoder_embed converts the input of shape (N, T, num_features)
|
||||
# to the shape (N, T//subsampling_factor, d_model).
|
||||
# That is, it does two things simultaneously:
|
||||
# (1) subsampling: T -> T//subsampling_factor
|
||||
# (2) embedding: num_features -> d_model
|
||||
self.encoder_embed = Conv2dSubsampling(num_features, 128, d_model)
|
||||
|
||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||
|
||||
@ -80,11 +79,13 @@ class Conformer(Transformer):
|
||||
dim_feedforward,
|
||||
dropout,
|
||||
cnn_module_kernel,
|
||||
normalize_before,
|
||||
)
|
||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
|
||||
aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
self.encoder_output_layer = nn.Sequential(
|
||||
nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim)
|
||||
)
|
||||
|
||||
|
||||
def forward(
|
||||
@ -136,7 +137,6 @@ class ConformerEncoderLayer(nn.Module):
|
||||
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
||||
dropout: the dropout value (default=0.1).
|
||||
cnn_module_kernel (int): Kernel size of convolution module.
|
||||
normalize_before: whether to use layer_norm before the first block.
|
||||
|
||||
Examples::
|
||||
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
||||
@ -152,7 +152,6 @@ class ConformerEncoderLayer(nn.Module):
|
||||
dim_feedforward: int = 2048,
|
||||
dropout: float = 0.1,
|
||||
cnn_module_kernel: int = 31,
|
||||
normalize_before: bool = True,
|
||||
) -> None:
|
||||
super(ConformerEncoderLayer, self).__init__()
|
||||
self.d_model = d_model
|
||||
@ -942,6 +941,80 @@ class Identity(torch.nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class Conv2dSubsampling(nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/4 length).
|
||||
|
||||
Convert an input of shape (N, T, idim) to an output
|
||||
with shape (N, T', odim), where
|
||||
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
|
||||
|
||||
It is based on
|
||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int,
|
||||
out_channels: int,
|
||||
layer1_channels: int = 64,
|
||||
layer2_channels: int = 128) -> None:
|
||||
"""
|
||||
Args:
|
||||
in_channels:
|
||||
Number of channels in. The input shape is (N, T, in_channels).
|
||||
Caution: It requires: T >=7, in_channels >=7
|
||||
out_channels
|
||||
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
|
||||
layer1_channels:
|
||||
Number of channels in layer1
|
||||
layer1_channels:
|
||||
Number of channels in layer2
|
||||
"""
|
||||
assert in_channels >= 7
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
ScaledConv2d(
|
||||
in_channels=1, out_channels=layer1_channels,
|
||||
kernel_size=3, stride=2
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
DoubleSwish(),
|
||||
ScaledConv2d(
|
||||
in_channels=layer1_channels, out_channels=layer2_channels,
|
||||
kernel_size=3, stride=2
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
DoubleSwish(),
|
||||
)
|
||||
self.out = ScaledLinear(layer2_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels)
|
||||
# set learn_eps=False because out_norm is preceded by `out`, and `out`
|
||||
# itself has learned scale, so the extra degree of freedom is not
|
||||
# needed.
|
||||
self.out_norm = BasicNorm(out_channels, learn_eps=False)
|
||||
# constrain median of output to be close to zero.
|
||||
self.out_balancer = ActivationBalancer(channel_dim=-1,
|
||||
min_positive=0.45,
|
||||
max_positive=0.55)
|
||||
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x:
|
||||
Its shape is (N, T, idim).
|
||||
|
||||
Returns:
|
||||
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
|
||||
"""
|
||||
# On entry, x is (N, T, idim)
|
||||
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
||||
x = self.conv(x)
|
||||
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
|
||||
x = self.out_norm(x)
|
||||
x = self.out_balancer(x)
|
||||
return x
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -32,34 +32,43 @@ class Conv2dSubsampling(nn.Module):
|
||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
|
||||
"""
|
||||
|
||||
def __init__(self, idim: int, odim: int) -> None:
|
||||
def __init__(self, in_channels: int,
|
||||
out_channels: int,
|
||||
layer1_channels: int = 64,
|
||||
layer2_channels: int = 128) -> None:
|
||||
"""
|
||||
Args:
|
||||
idim:
|
||||
Input dim. The input shape is (N, T, idim).
|
||||
Caution: It requires: T >=7, idim >=7
|
||||
odim:
|
||||
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
|
||||
in_channels:
|
||||
Number of channels in. The input shape is (N, T, in_channels).
|
||||
Caution: It requires: T >=7, in_channels >=7
|
||||
out_channels
|
||||
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
|
||||
layer1_channels:
|
||||
Number of channels in layer1
|
||||
layer1_channels:
|
||||
Number of channels in layer2
|
||||
"""
|
||||
assert idim >= 7
|
||||
assert in_channels >= 7
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
ScaledConv2d(
|
||||
in_channels=1, out_channels=odim, kernel_size=3, stride=2
|
||||
in_channels=1, out_channels=layer1_channels,
|
||||
kernel_size=3, stride=2
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
DoubleSwish(),
|
||||
ScaledConv2d(
|
||||
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
|
||||
in_channels=layer1_channels, out_channels=layer2_channels,
|
||||
kernel_size=3, stride=2
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
DoubleSwish(),
|
||||
)
|
||||
self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
|
||||
self.out = ScaledLinear(layer2_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels)
|
||||
# set learn_eps=False because out_norm is preceded by `out`, and `out`
|
||||
# itself has learned scale, so the extra degree of freedom is not
|
||||
# needed.
|
||||
self.out_norm = BasicNorm(odim, learn_eps=False)
|
||||
self.out_norm = BasicNorm(out_channels, learn_eps=False)
|
||||
# constrain median of output to be close to zero.
|
||||
self.out_balancer = ActivationBalancer(channel_dim=-1,
|
||||
min_positive=0.45,
|
||||
@ -86,99 +95,3 @@ class Conv2dSubsampling(nn.Module):
|
||||
x = self.out_norm(x)
|
||||
x = self.out_balancer(x)
|
||||
return x
|
||||
|
||||
|
||||
class VggSubsampling(nn.Module):
|
||||
"""Trying to follow the setup described in the following paper:
|
||||
https://arxiv.org/pdf/1910.09799.pdf
|
||||
|
||||
This paper is not 100% explicit so I am guessing to some extent,
|
||||
and trying to compare with other VGG implementations.
|
||||
|
||||
Convert an input of shape (N, T, idim) to an output
|
||||
with shape (N, T', odim), where
|
||||
T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
|
||||
"""
|
||||
|
||||
def __init__(self, idim: int, odim: int) -> None:
|
||||
"""Construct a VggSubsampling object.
|
||||
|
||||
This uses 2 VGG blocks with 2 Conv2d layers each,
|
||||
subsampling its input by a factor of 4 in the time dimensions.
|
||||
|
||||
Args:
|
||||
idim:
|
||||
Input dim. The input shape is (N, T, idim).
|
||||
Caution: It requires: T >=7, idim >=7
|
||||
odim:
|
||||
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
cur_channels = 1
|
||||
layers = []
|
||||
block_dims = [32, 64]
|
||||
|
||||
# The decision to use padding=1 for the 1st convolution, then padding=0
|
||||
# for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
|
||||
# a back-compatibility concern so that the number of frames at the
|
||||
# output would be equal to:
|
||||
# (((T-1)//2)-1)//2.
|
||||
# We can consider changing this by using padding=1 on the
|
||||
# 2nd convolution, so the num-frames at the output would be T//4.
|
||||
for block_dim in block_dims:
|
||||
layers.append(
|
||||
torch.nn.Conv2d(
|
||||
in_channels=cur_channels,
|
||||
out_channels=block_dim,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=1,
|
||||
)
|
||||
)
|
||||
layers.append(torch.nn.ReLU())
|
||||
layers.append(
|
||||
torch.nn.Conv2d(
|
||||
in_channels=block_dim,
|
||||
out_channels=block_dim,
|
||||
kernel_size=3,
|
||||
padding=0,
|
||||
stride=1,
|
||||
)
|
||||
)
|
||||
layers.append(
|
||||
torch.nn.MaxPool2d(
|
||||
kernel_size=2, stride=2, padding=0, ceil_mode=True
|
||||
)
|
||||
)
|
||||
cur_channels = block_dim
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
self.out = nn.Linear(
|
||||
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
|
||||
)
|
||||
self.out_norm = BasicNorm(odim, learn_eps=False)
|
||||
# constrain median of output to be close to zero.
|
||||
self.out_balancer = ActivationBalancer(channel_dim=-1,
|
||||
min_positive=0.45,
|
||||
max_positive=0.55)
|
||||
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x:
|
||||
Its shape is (N, T, idim).
|
||||
|
||||
Returns:
|
||||
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
|
||||
"""
|
||||
x = x.unsqueeze(1)
|
||||
x = self.layers(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
x = self.out_norm(x)
|
||||
x = self.out_balancer(x)
|
||||
return x
|
||||
|
@ -291,7 +291,6 @@ def get_params() -> AttributeDict:
|
||||
"nhead": 8,
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 12,
|
||||
"vgg_frontend": False,
|
||||
# parameters for decoder
|
||||
"embedding_dim": 512,
|
||||
# parameters for Noam
|
||||
@ -314,7 +313,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
nhead=params.nhead,
|
||||
dim_feedforward=params.dim_feedforward,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
vgg_frontend=params.vgg_frontend,
|
||||
)
|
||||
return encoder
|
||||
|
||||
|
@ -78,10 +78,7 @@ class Transformer(EncoderInterface):
|
||||
# That is, it does two things simultaneously:
|
||||
# (1) subsampling: T -> T//subsampling_factor
|
||||
# (2) embedding: num_features -> d_model
|
||||
if vgg_frontend:
|
||||
self.encoder_embed = VggSubsampling(num_features, d_model)
|
||||
else:
|
||||
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
|
||||
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
|
||||
|
||||
self.encoder_pos = PositionalEncoding(d_model, dropout)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user