diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py
index 622495f21..29621bf52 100644
--- a/egs/librispeech/ASR/conformer_ctc/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py
@@ -58,6 +58,7 @@ class Conv2dSubsampling(nn.Module):
             ExpScaleRelu(odim, 1, 1, speed=20.0),
         )
         self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
+        self.out_norm = BasicNorm(odim)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
@@ -76,7 +77,7 @@ class Conv2dSubsampling(nn.Module):
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
         # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
-        x = x * 0.1
+        x = self.out_norm(x)
         return x


@@ -200,9 +201,11 @@ class PeLU(torch.nn.Module):
         return PeLUFunction.apply(x, self.cutoff, self.alpha)

 class ExpScale(torch.nn.Module):
-    def __init__(self, *shape, speed: float = 1.0):
+    def __init__(self, *shape, speed: float = 1.0, initial_scale: float = 1.0):
         super(ExpScale, self).__init__()
-        self.scale = nn.Parameter(torch.zeros(*shape))
+        scale = torch.tensor(initial_scale)
+        scale = scale.log() / speed
+        self.scale = nn.Parameter(scale.detach())
         self.speed = speed

     def forward(self, x: Tensor) -> Tensor:
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index 62d9f382f..acaf064b3 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple, Sequence
-from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer
+from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer, BasicNorm

 import torch
 from torch import Tensor, nn
@@ -150,6 +150,8 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
+        self.d_model = d_model
+
         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=0.0
         )
@@ -174,22 +176,15 @@ class ConformerEncoderLayer(nn.Module):

         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)

-        self.norm_ff_macaron = nn.LayerNorm(
-            d_model
-        )  # for the macaron style FNN module
-        self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
-        self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
+        self.scale_mha = ExpScale(1, speed=10.0, initial_scale=0.2)
+        self.scale_conv = ExpScale(1, speed=10.0, initial_scale=0.5)
+        self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5)
+        self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5)

-        self.ff_scale = 0.5
-
-        self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
-        self.norm_final = nn.LayerNorm(
-            d_model
-        )  # for the final output of the block
+        self.norm_final = BasicNorm(d_model)

         self.dropout = nn.Dropout(dropout)

-        self.normalize_before = normalize_before

     def forward(
         self,
@@ -217,18 +212,15 @@ class ConformerEncoderLayer(nn.Module):

         # macaron style feed forward module
         residual = src
-        if self.normalize_before:
-            src = self.norm_ff_macaron(src)
-        src = residual + self.ff_scale * self.dropout(
-            self.feed_forward_macaron(src)
-        )
-        if not self.normalize_before:
-            src = self.norm_ff_macaron(src)
+
+
+        src = src + self.dropout(self.feed_forward_macaron(
+            self.scale_ff_macaron(src)))
+

         # multi-headed self-attention module
         residual = src
-        if self.normalize_before:
-            src = self.norm_mha(src)
+        src = self.scale_mha(src)
         src_att = self.self_attn(
             src,
             src,
@@ -238,27 +230,14 @@ class ConformerEncoderLayer(nn.Module):
             key_padding_mask=src_key_padding_mask,
         )[0]
         src = residual + self.dropout(src_att)
-        if not self.normalize_before:
-            src = self.norm_mha(src)

         # convolution module
-        residual = src
-        if self.normalize_before:
-            src = self.norm_conv(src)
-        src = residual + self.dropout(self.conv_module(src))
-        if not self.normalize_before:
-            src = self.norm_conv(src)
+        src = residual + self.dropout(self.conv_module(self.scale_conv(src)))

         # feed forward module
-        residual = src
-        if self.normalize_before:
-            src = self.norm_ff(src)
-        src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
-        if not self.normalize_before:
-            src = self.norm_ff(src)
+        src = src + self.dropout(self.feed_forward(self.scale_ff(src)))

-        if self.normalize_before:
-            src = self.norm_final(src)
+        src = self.norm_final(src)

         return src

@@ -288,7 +267,7 @@ class ConformerEncoder(nn.Module):
         self.aux_layers = set(aux_layers + [num_layers - 1])
         assert num_layers - 1 not in aux_layers
         self.num_layers = num_layers
-        num_channels = encoder_layer.norm_final.weight.numel()
+        num_channels = encoder_layer.d_model
         self.combiner = RandomCombine(num_inputs=len(self.aux_layers),
                                       num_channels=num_channels,
                                       final_weight=0.5,
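---

Notes (not part of the patch):

* ExpScale's forward pass is untouched here; assuming it computes
  `x * (self.scale * self.speed).exp()`, the new initialization
  `scale = log(initial_scale) / speed` makes the module start out as
  multiplication by exactly `initial_scale`, since
  exp(speed * log(initial_scale) / speed) = initial_scale. The `speed`
  factor then only rescales the gradient with respect to the parameter,
  not the initial behavior.

* BasicNorm is used by both files but its definition is not included in
  this diff. Below is a minimal sketch of what the call sites appear to
  assume: an RMS-style normalization over the channel dimension with a
  learned epsilon and no affine weight/bias. Only the class name and the
  single-argument constructor come from the patch; the body is
  illustrative.

    import torch
    from torch import Tensor, nn

    class BasicNorm(nn.Module):
        def __init__(self, num_channels: int, eps: float = 0.25) -> None:
            super().__init__()
            self.num_channels = num_channels
            # keep eps in log space so the learned value stays positive
            self.eps = nn.Parameter(torch.tensor(eps).log().detach())

        def forward(self, x: Tensor) -> Tensor:
            # divide by the root-mean-square over the last (channel)
            # dimension, with the learned epsilon inside the sqrt
            assert x.shape[-1] == self.num_channels
            scales = (
                torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps.exp()
            ) ** -0.5
            return x * scales

  Unlike nn.LayerNorm there is no mean subtraction and no per-channel
  scale/offset, which is consistent with the patch replacing the
  `x = x * 0.1` output scaling and the per-module LayerNorms with a
  single normalization at the block output.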