Restore ConvolutionModule to state before changes; change all Swish,Swish(Swish) to SwishOffset.

This commit is contained in:
Daniel Povey 2022-03-06 11:55:02 +08:00
parent 8a8b81cd18
commit a37d98463a
3 changed files with 20 additions and 9 deletions

View File

@ -212,9 +212,8 @@ class ExpScale(torch.nn.Module):
def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor: def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor:
# double-swish! # double-swish, implemented/approximated as offset-swish
x = (x * torch.sigmoid(x)) x = (x * torch.sigmoid(x - 1.0))
x = (x * torch.sigmoid(x))
x = x * (scale * speed).exp() x = x * (scale * speed).exp()
return x return x

View File

@ -877,10 +877,10 @@ class ConvolutionModule(nn.Module):
groups=channels, groups=channels,
bias=bias, bias=bias,
) )
self.balancer = DerivBalancer(channel_dim=1, threshold=0.05,
max_factor=0.025) self.norm = nn.LayerNorm(channels)
# shape: (channels, 1), broadcasts with (batch, channel, time). # shape: (channels, 1), broadcasts with (batch, channel, time).
self.activation = ExpScaleSwish(channels, 1, speed=20.0) self.activation = SwishOffset()
self.pointwise_conv2 = nn.Conv1d( self.pointwise_conv2 = nn.Conv1d(
channels, channels,
@ -911,8 +911,10 @@ class ConvolutionModule(nn.Module):
# 1D Depthwise Conv # 1D Depthwise Conv
x = self.depthwise_conv(x) x = self.depthwise_conv(x)
# x is (batch, channels, time) # x is (batch, channels, time)
x = x.permute(0, 2, 1)
x = self.norm(x)
x = x.permute(0, 2, 1)
x = self.balancer(x)
x = self.activation(x) x = self.activation(x)
x = self.pointwise_conv2(x) # (batch, channel, time) x = self.pointwise_conv2(x) # (batch, channel, time)
@ -927,6 +929,16 @@ class Swish(torch.nn.Module):
"""Return Swich activation function.""" """Return Swich activation function."""
return x * torch.sigmoid(x) return x * torch.sigmoid(x)
class SwishOffset(torch.nn.Module):
"""Construct an SwishOffset object."""
def __init__(self, offset: float = -1.0) -> None:
super(SwishOffset, self).__init__()
self.offset = offset
def forward(self, x: Tensor) -> Tensor:
"""Return Swich activation function."""
return x * torch.sigmoid(x + self.offset)
def identity(x): def identity(x):
return x return x

View File

@ -110,7 +110,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2swish2", default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2",
help="""The experiment dir. help="""The experiment dir.
It specifies the directory where all training related It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved