Replace relu with swish-squared (swish applied twice, i.e. double-swish).

This commit is contained in:
Daniel Povey 2022-03-05 22:21:42 +08:00
parent 5f2c0a09b7
commit 8a8b81cd18
3 changed files with 14 additions and 12 deletions

View File

@@ -212,12 +212,11 @@ class ExpScale(torch.nn.Module):
def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor: def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor:
return (x * torch.sigmoid(x)) * (scale * speed).exp() # double-swish!
x = (x * torch.sigmoid(x))
x = (x * torch.sigmoid(x))
def _exp_scale_swish_backward(x: Tensor, scale: Tensor, speed: float) -> Tensor: x = x * (scale * speed).exp()
return (x * torch.sigmoid(x)) * (scale * speed).exp() return x
class ExpScaleSwishFunction(torch.autograd.Function): class ExpScaleSwishFunction(torch.autograd.Function):
@staticmethod @staticmethod
@@ -247,8 +246,11 @@ class ExpScaleSwish(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return ExpScaleSwishFunction.apply(x, self.scale, self.speed) return ExpScaleSwishFunction.apply(x, self.scale, self.speed)
# return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp() # x = (x * torch.sigmoid(x))
# return x * (self.scale * self.speed).exp() # x = (x * torch.sigmoid(x))
# x = x * (self.scale * self.speed).exp()
# return x
def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor: def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor:

View File

@@ -158,7 +158,7 @@ class ConformerEncoderLayer(nn.Module):
nn.Linear(d_model, dim_feedforward), nn.Linear(d_model, dim_feedforward),
DerivBalancer(channel_dim=-1, threshold=0.05, DerivBalancer(channel_dim=-1, threshold=0.05,
max_factor=0.025), max_factor=0.025),
ExpScaleRelu(dim_feedforward, speed=20.0), ExpScaleSwish(dim_feedforward, speed=20.0),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model), nn.Linear(dim_feedforward, d_model),
) )
@@ -167,7 +167,7 @@ class ConformerEncoderLayer(nn.Module):
nn.Linear(d_model, dim_feedforward), nn.Linear(d_model, dim_feedforward),
DerivBalancer(channel_dim=-1, threshold=0.05, DerivBalancer(channel_dim=-1, threshold=0.05,
max_factor=0.025), max_factor=0.025),
ExpScaleRelu(dim_feedforward, speed=20.0), ExpScaleSwish(dim_feedforward, speed=20.0),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model), nn.Linear(dim_feedforward, d_model),
) )
@@ -880,7 +880,7 @@ class ConvolutionModule(nn.Module):
self.balancer = DerivBalancer(channel_dim=1, threshold=0.05, self.balancer = DerivBalancer(channel_dim=1, threshold=0.05,
max_factor=0.025) max_factor=0.025)
# shape: (channels, 1), broadcasts with (batch, channel, time). # shape: (channels, 1), broadcasts with (batch, channel, time).
self.activation = ExpScaleRelu(channels, 1, speed=20.0) self.activation = ExpScaleSwish(channels, 1, speed=20.0)
self.pointwise_conv2 = nn.Conv1d( self.pointwise_conv2 = nn.Conv1d(
channels, channels,

View File

@@ -110,7 +110,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2relu", default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2swish2",
help="""The experiment dir. help="""The experiment dir.
It specifies the directory where all training related It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved