mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
Restore ConvolutionModule to state before changes; change all Swish,Swish(Swish) to SwishOffset.
This commit is contained in:
parent
8a8b81cd18
commit
a37d98463a
@ -212,9 +212,8 @@ class ExpScale(torch.nn.Module):
|
||||
|
||||
|
||||
def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor:
    """Apply offset-swish to ``x`` and rescale it by ``exp(scale * speed)``.

    The activation ``x * sigmoid(x - 1.0)`` is used as an approximation of
    double-swish (swish applied twice); it is applied exactly once here.

    Args:
      x: Input tensor.
      scale: Log-domain scale; it must broadcast against ``x``.
      speed: Multiplier applied to ``scale`` before exponentiation.

    Returns:
      ``x * sigmoid(x - 1.0) * exp(scale * speed)``.
    """
    # double-swish, implemented/approximated as offset-swish
    x = x * torch.sigmoid(x - 1.0)
    x = x * (scale * speed).exp()
    return x
|
||||
|
||||
|
@ -877,10 +877,10 @@ class ConvolutionModule(nn.Module):
|
||||
groups=channels,
|
||||
bias=bias,
|
||||
)
|
||||
self.balancer = DerivBalancer(channel_dim=1, threshold=0.05,
|
||||
max_factor=0.025)
|
||||
# shape: (channels, 1), broadcasts with (batch, channel, time).
|
||||
self.activation = ExpScaleSwish(channels, 1, speed=20.0)
|
||||
|
||||
self.norm = nn.LayerNorm(channels)
|
||||
# shape: (channels, 1), broadcasts with (batch, channel, time).
|
||||
self.activation = SwishOffset()
|
||||
|
||||
self.pointwise_conv2 = nn.Conv1d(
|
||||
channels,
|
||||
@ -911,8 +911,10 @@ class ConvolutionModule(nn.Module):
|
||||
# 1D Depthwise Conv
|
||||
x = self.depthwise_conv(x)
|
||||
# x is (batch, channels, time)
|
||||
x = x.permute(0, 2, 1)
|
||||
x = self.norm(x)
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
x = self.balancer(x)
|
||||
x = self.activation(x)
|
||||
|
||||
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||
@ -927,6 +929,16 @@ class Swish(torch.nn.Module):
|
||||
"""Return Swich activation function."""
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
class SwishOffset(torch.nn.Module):
    """Swish activation whose sigmoid gate is shifted by a constant offset.

    Computes ``x * sigmoid(x + offset)``.  With ``offset=0.0`` this is the
    ordinary Swish; the default ``offset=-1.0`` gives ``x * sigmoid(x - 1)``.

    Args:
      offset: Constant added to the input before the sigmoid gate.
    """

    def __init__(self, offset: float = -1.0) -> None:
        super().__init__()
        self.offset = offset

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x * sigmoid(x + offset)``, applied elementwise."""
        return x * torch.sigmoid(x + self.offset)

    def extra_repr(self) -> str:
        """Show the configured offset when the module is printed."""
        return f"offset={self.offset}"
|
||||
|
||||
|
||||
def identity(x):
    """Return the argument unchanged."""
    return x
|
||||
|
@ -110,7 +110,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2swish2",
|
||||
default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
|
Loading…
x
Reference in New Issue
Block a user