Fix duplicate Swish; replace norm+swish with swish+exp-scale in convolution module

This commit is contained in:
Daniel Povey 2022-03-04 15:50:51 +08:00
parent 7e88999641
commit 9cc5999829
2 changed files with 4 additions and 7 deletions

View File

@ -163,7 +163,6 @@ class ConformerEncoderLayer(nn.Module):
self.feed_forward_macaron = nn.Sequential( self.feed_forward_macaron = nn.Sequential(
nn.Linear(d_model, dim_feedforward), nn.Linear(d_model, dim_feedforward),
Swish(),
ExpScaleSwish(dim_feedforward, speed=50.0), ExpScaleSwish(dim_feedforward, speed=50.0),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model), nn.Linear(dim_feedforward, d_model),
@ -874,7 +873,9 @@ class ConvolutionModule(nn.Module):
groups=channels, groups=channels,
bias=bias, bias=bias,
) )
self.norm = nn.LayerNorm(channels) # shape: (channels, 1), broadcasts with (batch, channel, time).
self.activation = ExpScaleSwish(channels, 1, speed=50.0)
self.pointwise_conv2 = nn.Conv1d( self.pointwise_conv2 = nn.Conv1d(
channels, channels,
channels, channels,
@ -883,7 +884,6 @@ class ConvolutionModule(nn.Module):
padding=0, padding=0,
bias=bias, bias=bias,
) )
self.activation = Swish()
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
"""Compute convolution module. """Compute convolution module.
@ -905,9 +905,6 @@ class ConvolutionModule(nn.Module):
# 1D Depthwise Conv # 1D Depthwise Conv
x = self.depthwise_conv(x) x = self.depthwise_conv(x)
# x is (batch, channels, time) # x is (batch, channels, time)
x = x.permute(0, 2, 1)
x = self.norm(x)
x = x.permute(0, 2, 1)
x = self.activation(x) x = self.activation(x)

View File

@ -110,7 +110,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="transducer_stateless/specaugmod_baseline_randcombine1_expscale4", default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5",
help="""The experiment dir. help="""The experiment dir.
It specifies the directory where all training related It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved