mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00

Refactoring and simplifying conformer and frontend

parent 57f943b25c
commit 11124b03ea
@@ -16,6 +16,7 @@
 # limitations under the License.

 import copy
+from encoder_interface import EncoderInterface
 import math
 import warnings
 from typing import Optional, Tuple, Sequence
@@ -23,12 +24,11 @@ from scaling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, Sc

 import torch
 from torch import Tensor, nn
-from transformer import Transformer

 from icefall.utils import make_pad_mask


-class Conformer(Transformer):
+class Conformer(EncoderInterface):
     """
     Args:
         num_features (int): Number of input features
@@ -40,7 +40,6 @@ class Conformer(Transformer):
         num_encoder_layers (int): number of encoder layers
         dropout (float): dropout rate
         cnn_module_kernel (int): Kernel size of convolution module
-        normalize_before (bool): whether to use layer_norm before the first block.
         vgg_frontend (bool): whether to use vgg frontend.
     """

@@ -55,22 +54,22 @@
         num_encoder_layers: int = 12,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
-        normalize_before: bool = True,
-        vgg_frontend: bool = False,
         aux_layer_period: int = 3
     ) -> None:
-        super(Conformer, self).__init__(
-            num_features=num_features,
-            output_dim=output_dim,
-            subsampling_factor=subsampling_factor,
-            d_model=d_model,
-            nhead=nhead,
-            dim_feedforward=dim_feedforward,
-            num_encoder_layers=num_encoder_layers,
-            dropout=dropout,
-            normalize_before=normalize_before,
-            vgg_frontend=vgg_frontend,
-        )
+        super(Conformer, self).__init__()
+
+        self.num_features = num_features
+        self.output_dim = output_dim
+        self.subsampling_factor = subsampling_factor
+        if subsampling_factor != 4:
+            raise NotImplementedError("Support only 'subsampling_factor=4'.")
+
+        # self.encoder_embed converts the input of shape (N, T, num_features)
+        # to the shape (N, T//subsampling_factor, d_model).
+        # That is, it does two things simultaneously:
+        #   (1) subsampling: T -> T//subsampling_factor
+        #   (2) embedding: num_features -> d_model
+        self.encoder_embed = Conv2dSubsampling(num_features, 128, d_model)

         self.encoder_pos = RelPositionalEncoding(d_model, dropout)

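The new comment block pins down the contract of self.encoder_embed: (N, T, num_features) -> (N, T', d_model) with T' = ((T-1)//2 - 1)//2, which is roughly T//4. A minimal sketch of that length arithmetic (the helper name subsampled_length is ours, not part of the commit):

def subsampled_length(T: int) -> int:
    # Each of the two kernel-3, stride-2, padding-0 convolutions maps
    # a length L to (L - 1) // 2, so T -> ((T - 1) // 2 - 1) // 2.
    return ((T - 1) // 2 - 1) // 2

assert subsampled_length(1000) == 249  # close to 1000 // 4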
@@ -80,11 +79,13 @@ class Conformer(Transformer):
             dim_feedforward,
             dropout,
             cnn_module_kernel,
-            normalize_before,
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
                                         aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))
-        self.normalize_before = normalize_before
+        self.encoder_output_layer = nn.Sequential(
+            nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim)
+        )
+

     def forward(
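For the defaults above (num_encoder_layers=12, aux_layer_period=3), the aux_layers argument selects every third layer, excluding the last one; a quick check of the range expression:

num_encoder_layers, aux_layer_period = 12, 3
aux_layers = list(range(0, num_encoder_layers - 1, aux_layer_period))
assert aux_layers == [0, 3, 6, 9]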
@@ -136,7 +137,6 @@ class ConformerEncoderLayer(nn.Module):
         dim_feedforward: the dimension of the feedforward network model (default=2048).
         dropout: the dropout value (default=0.1).
         cnn_module_kernel (int): Kernel size of convolution module.
-        normalize_before: whether to use layer_norm before the first block.

     Examples::
         >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@@ -152,7 +152,6 @@
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
-        normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
         self.d_model = d_model
@@ -942,6 +941,80 @@ class Identity(torch.nn.Module):
         return x


+class Conv2dSubsampling(nn.Module):
+    """Convolutional 2D subsampling (to 1/4 length).
+
+    Convert an input of shape (N, T, idim) to an output
+    with shape (N, T', odim), where
+    T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
+
+    It is based on
+    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
+    """
+
+    def __init__(self, in_channels: int,
+                 out_channels: int,
+                 layer1_channels: int = 64,
+                 layer2_channels: int = 128) -> None:
+        """
+        Args:
+          in_channels:
+            Number of channels in. The input shape is (N, T, in_channels).
+            Caution: It requires: T >=7, in_channels >=7
+          out_channels:
+            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
+          layer1_channels:
+            Number of channels in layer1
+          layer2_channels:
+            Number of channels in layer2
+        """
+        assert in_channels >= 7
+        super().__init__()
+        self.conv = nn.Sequential(
+            ScaledConv2d(
+                in_channels=1, out_channels=layer1_channels,
+                kernel_size=3, stride=2
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+            ScaledConv2d(
+                in_channels=layer1_channels, out_channels=layer2_channels,
+                kernel_size=3, stride=2
+            ),
+            ActivationBalancer(channel_dim=1),
+            DoubleSwish(),
+        )
+        self.out = ScaledLinear(layer2_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels)
+        # set learn_eps=False because out_norm is preceded by `out`, and `out`
+        # itself has learned scale, so the extra degree of freedom is not
+        # needed.
+        self.out_norm = BasicNorm(out_channels, learn_eps=False)
+        # constrain median of output to be close to zero.
+        self.out_balancer = ActivationBalancer(channel_dim=-1,
+                                               min_positive=0.45,
+                                               max_positive=0.55)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Subsample x.
+
+        Args:
+          x:
+            Its shape is (N, T, idim).
+
+        Returns:
+          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
+        """
+        # On entry, x is (N, T, idim)
+        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
+        x = self.conv(x)
+        # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
+        x = self.out_norm(x)
+        x = self.out_balancer(x)
+        return x


 if __name__ == '__main__':
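A self-contained shape check for the frontend added above. It substitutes plain nn.Conv2d/nn.ReLU for ScaledConv2d, ActivationBalancer and DoubleSwish (which live in icefall's scaling.py); only the kernel/stride geometry determines the output shape, so this is a sketch, not the module itself:

import torch
import torch.nn as nn

conv = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=3, stride=2),    # layer1_channels=64
    nn.ReLU(),                                    # stand-in for DoubleSwish
    nn.Conv2d(64, 128, kernel_size=3, stride=2),  # layer2_channels=128
    nn.ReLU(),
)
N, T, idim = 2, 100, 80
x = torch.randn(N, 1, T, idim)  # (N, C=1, H=T, W=idim), as in forward()
y = conv(x)
assert y.shape == (N, 128, ((T - 1) // 2 - 1) // 2, ((idim - 1) // 2 - 1) // 2)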
@@ -32,34 +32,43 @@ class Conv2dSubsampling(nn.Module):
     https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
     """

-    def __init__(self, idim: int, odim: int) -> None:
+    def __init__(self, in_channels: int,
+                 out_channels: int,
+                 layer1_channels: int = 64,
+                 layer2_channels: int = 128) -> None:
         """
         Args:
-          idim:
-            Input dim. The input shape is (N, T, idim).
-            Caution: It requires: T >=7, idim >=7
-          odim:
-            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
+          in_channels:
+            Number of channels in. The input shape is (N, T, in_channels).
+            Caution: It requires: T >=7, in_channels >=7
+          out_channels:
+            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels)
+          layer1_channels:
+            Number of channels in layer1
+          layer2_channels:
+            Number of channels in layer2
         """
-        assert idim >= 7
+        assert in_channels >= 7
         super().__init__()
         self.conv = nn.Sequential(
             ScaledConv2d(
-                in_channels=1, out_channels=odim, kernel_size=3, stride=2
+                in_channels=1, out_channels=layer1_channels,
+                kernel_size=3, stride=2
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
             ScaledConv2d(
-                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
+                in_channels=layer1_channels, out_channels=layer2_channels,
+                kernel_size=3, stride=2
             ),
             ActivationBalancer(channel_dim=1),
             DoubleSwish(),
         )
-        self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim)
+        self.out = ScaledLinear(layer2_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels)
         # set learn_eps=False because out_norm is preceded by `out`, and `out`
         # itself has learned scale, so the extra degree of freedom is not
         # needed.
-        self.out_norm = BasicNorm(odim, learn_eps=False)
+        self.out_norm = BasicNorm(out_channels, learn_eps=False)
         # constrain median of output to be close to zero.
         self.out_balancer = ActivationBalancer(channel_dim=-1,
                                                min_positive=0.45,
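The in_features of self.out follows from flattening the (channel, frequency) axes after the two stride-2 convolutions. For example, with 80-dim input features and the default layer2_channels=128 (example values of ours, not taken from the diff):

in_channels, layer2_channels = 80, 128
freq = ((in_channels - 1) // 2 - 1) // 2  # 80 -> 39 -> 19
assert layer2_channels * freq == 2432     # in_features of self.out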
@@ -86,99 +95,3 @@ class Conv2dSubsampling(nn.Module):
         x = self.out_norm(x)
         x = self.out_balancer(x)
         return x
-
-
-class VggSubsampling(nn.Module):
-    """Trying to follow the setup described in the following paper:
-    https://arxiv.org/pdf/1910.09799.pdf
-
-    This paper is not 100% explicit so I am guessing to some extent,
-    and trying to compare with other VGG implementations.
-
-    Convert an input of shape (N, T, idim) to an output
-    with shape (N, T', odim), where
-    T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
-    """
-
-    def __init__(self, idim: int, odim: int) -> None:
-        """Construct a VggSubsampling object.
-
-        This uses 2 VGG blocks with 2 Conv2d layers each,
-        subsampling its input by a factor of 4 in the time dimensions.
-
-        Args:
-          idim:
-            Input dim. The input shape is (N, T, idim).
-            Caution: It requires: T >=7, idim >=7
-          odim:
-            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
-        """
-        super().__init__()
-
-        cur_channels = 1
-        layers = []
-        block_dims = [32, 64]
-
-        # The decision to use padding=1 for the 1st convolution, then padding=0
-        # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
-        # a back-compatibility concern so that the number of frames at the
-        # output would be equal to:
-        #   (((T-1)//2)-1)//2.
-        # We can consider changing this by using padding=1 on the
-        # 2nd convolution, so the num-frames at the output would be T//4.
-        for block_dim in block_dims:
-            layers.append(
-                torch.nn.Conv2d(
-                    in_channels=cur_channels,
-                    out_channels=block_dim,
-                    kernel_size=3,
-                    padding=1,
-                    stride=1,
-                )
-            )
-            layers.append(torch.nn.ReLU())
-            layers.append(
-                torch.nn.Conv2d(
-                    in_channels=block_dim,
-                    out_channels=block_dim,
-                    kernel_size=3,
-                    padding=0,
-                    stride=1,
-                )
-            )
-            layers.append(
-                torch.nn.MaxPool2d(
-                    kernel_size=2, stride=2, padding=0, ceil_mode=True
-                )
-            )
-            cur_channels = block_dim
-
-        self.layers = nn.Sequential(*layers)
-
-        self.out = nn.Linear(
-            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
-        )
-        self.out_norm = BasicNorm(odim, learn_eps=False)
-        # constrain median of output to be close to zero.
-        self.out_balancer = ActivationBalancer(channel_dim=-1,
-                                               min_positive=0.45,
-                                               max_positive=0.55)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Subsample x.
-
-        Args:
-          x:
-            Its shape is (N, T, idim).
-
-        Returns:
-          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
-        """
-        x = x.unsqueeze(1)
-        x = self.layers(x)
-        b, c, t, f = x.size()
-        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
-        x = self.out_norm(x)
-        x = self.out_balancer(x)
-        return x
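The deleted comment claims that padding=1 on the first convolution, padding=0 on the second and on the max-pooling, plus ceil_mode=True, yields exactly (((T-1)//2)-1)//2 output frames, i.e. the same count as Conv2dSubsampling. A runnable sketch of that claim with plain layers (verification code of ours, not from the commit):

import torch
import torch.nn as nn

def vgg_block(in_ch: int, out_ch: int) -> nn.Sequential:
    # conv(pad=1) keeps T, conv(pad=0) gives T - 2, ceil-mode pooling halves.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, stride=1),
        nn.ReLU(),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=0, stride=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True),
    )

layers = nn.Sequential(vgg_block(1, 32), vgg_block(32, 64))
for T in (7, 100, 101):
    t_out = layers(torch.randn(1, 1, T, 80)).shape[2]
    assert t_out == ((T - 1) // 2 - 1) // 2  # matches Conv2dSubsampling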
@@ -291,7 +291,6 @@ def get_params() -> AttributeDict:
             "nhead": 8,
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
-            "vgg_frontend": False,
             # parameters for decoder
             "embedding_dim": 512,
             # parameters for Noam
@@ -314,7 +313,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         nhead=params.nhead,
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
     )
     return encoder

@@ -78,10 +78,7 @@ class Transformer(EncoderInterface):
         # That is, it does two things simultaneously:
         #   (1) subsampling: T -> T//subsampling_factor
         #   (2) embedding: num_features -> d_model
-        if vgg_frontend:
-            self.encoder_embed = VggSubsampling(num_features, d_model)
-        else:
-            self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+        self.encoder_embed = Conv2dSubsampling(num_features, d_model)

         self.encoder_pos = PositionalEncoding(d_model, dropout)
