diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py
index a500e42a9..daf8fd251 100644
--- a/egs/librispeech/ASR/conformer_ctc/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py
@@ -212,12 +212,11 @@ class ExpScale(torch.nn.Module):
 
 
 def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor:
-    return (x * torch.sigmoid(x)) * (scale * speed).exp()
-
-
-def _exp_scale_swish_backward(x: Tensor, scale: Tensor, speed: float) -> Tensor:
-    return (x * torch.sigmoid(x)) * (scale * speed).exp()
-
+    # double-swish!
+    x = (x * torch.sigmoid(x))
+    x = (x * torch.sigmoid(x))
+    x = x * (scale * speed).exp()
+    return x
 
 class ExpScaleSwishFunction(torch.autograd.Function):
     @staticmethod
@@ -247,8 +246,11 @@ class ExpScaleSwish(torch.nn.Module):
 
     def forward(self, x: Tensor) -> Tensor:
         return ExpScaleSwishFunction.apply(x, self.scale, self.speed)
-        # return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp()
-        # return x * (self.scale * self.speed).exp()
+        # x = (x * torch.sigmoid(x))
+        # x = (x * torch.sigmoid(x))
+        # x = x * (self.scale * self.speed).exp()
+        # return x
+
 
 
 def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor:
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index 7af145a1e..5adb7ca4e 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -158,7 +158,7 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(d_model, dim_feedforward),
             DerivBalancer(channel_dim=-1, threshold=0.05,
                           max_factor=0.025),
-            ExpScaleRelu(dim_feedforward, speed=20.0),
+            ExpScaleSwish(dim_feedforward, speed=20.0),
             nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )
@@ -167,7 +167,7 @@ class ConformerEncoderLayer(nn.Module):
             nn.Linear(d_model, dim_feedforward),
             DerivBalancer(channel_dim=-1, threshold=0.05,
                           max_factor=0.025),
-            ExpScaleRelu(dim_feedforward, speed=20.0),
+            ExpScaleSwish(dim_feedforward, speed=20.0),
             nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )
@@ -880,7 +880,7 @@ class ConvolutionModule(nn.Module):
         self.balancer = DerivBalancer(channel_dim=1, threshold=0.05,
                                       max_factor=0.025)
         # shape: (channels, 1), broadcasts with (batch, channel, time).
-        self.activation = ExpScaleRelu(channels, 1, speed=20.0)
+        self.activation = ExpScaleSwish(channels, 1, speed=20.0)
 
         self.pointwise_conv2 = nn.Conv1d(
             channels,
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index b1cb6d043..a3eca26c9 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2relu",
+        default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2swish2",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
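
For reference, a minimal standalone sketch (not part of the patch) of what the new _exp_scale_swish forward computes, with scale treated as a learnable per-channel parameter as in ExpScaleSwish; the helper name double_swish_exp_scale below is hypothetical:

import torch
from torch import Tensor

def double_swish_exp_scale(x: Tensor, scale: Tensor, speed: float = 20.0) -> Tensor:
    # swish (x * sigmoid(x)) applied twice, i.e. "double-swish",
    # followed by a per-channel exponential scaling factor exp(scale * speed)
    x = x * torch.sigmoid(x)
    x = x * torch.sigmoid(x)
    return x * (scale * speed).exp()

# example call: y = double_swish_exp_scale(torch.randn(8, 256), torch.zeros(256))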