diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py
index 3b35c2ebe..600156bf1 100644
--- a/egs/librispeech/ASR/conformer_ctc/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py
@@ -206,3 +206,70 @@ class ExpScale(torch.nn.Module):
 
     def forward(self, x: Tensor) -> Tensor:
         return x * (self.scale * self.speed).exp()
+
+
+
+def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor:
+    return (x * torch.sigmoid(x)) * (scale * speed).exp()
+
+
+def _exp_scale_swish_backward(x: Tensor, scale: Tensor, speed: float) -> Tensor:
+    return (x * torch.sigmoid(x)) * (scale * speed).exp()
+
+
+class ExpScaleSwishFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
+        ctx.save_for_backward(x, scale)
+        ctx.speed = speed
+        return _exp_scale_swish(x, scale, speed)
+
+    @staticmethod
+    def backward(ctx, y_grad: Tensor) -> Tensor:
+        x, scale = ctx.saved_tensors
+        x.requires_grad = True
+        scale.requires_grad = True
+        with torch.enable_grad():
+            y = _exp_scale_swish(x, scale, ctx.speed)
+            y.backward(gradient=y_grad)
+            return x.grad, scale.grad, None
+
+
+class ExpScaleSwish(torch.nn.Module):
+    # combines ExpScale and Swish
+    # caution: need to specify name for speed, e.g. ExpScaleSwish(50, speed=4.0)
+    def __init__(self, *shape, speed: float = 1.0):
+        super(ExpScaleSwish, self).__init__()
+        self.scale = nn.Parameter(torch.zeros(*shape))
+        self.speed = speed
+
+    def forward(self, x: Tensor) -> Tensor:
+        return ExpScaleSwishFunction.apply(x, self.scale, self.speed)
+        # return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp()
+        # return x * (self.scale * self.speed).exp()
+
+def _test_exp_scale_swish():
+    class Swish(torch.nn.Module):
+        def forward(self, x: Tensor) -> Tensor:
+            """Return Swish activation function."""
+            return x * torch.sigmoid(x)
+
+    x1 = torch.randn(50, 60).detach()
+    x2 = x1.detach()
+
+    m1 = ExpScaleSwish(50, 1, speed=4.0)
+    m2 = torch.nn.Sequential(Swish(), ExpScale(50, 1, speed=4.0))
+    x1.requires_grad = True
+    x2.requires_grad = True
+
+    y1 = m1(x1)
+    y2 = m2(x2)
+    assert torch.allclose(y1, y2)
+    y1.sum().backward()
+    y2.sum().backward()
+    assert torch.allclose(x1.grad, x2.grad)
+
+
+
+if __name__ == '__main__':
+    _test_exp_scale_swish()
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index 59f317e90..3386ed9b2 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -156,8 +156,7 @@ class ConformerEncoderLayer(nn.Module):
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
-            Swish(),
-            ExpScale(dim_feedforward, speed=4.0),
+            ExpScaleSwish(dim_feedforward, speed=4.0),
            nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )
@@ -165,7 +164,7 @@
         self.feed_forward_macaron = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
             Swish(),
-            ExpScale(dim_feedforward, speed=4.0),
+            ExpScaleSwish(dim_feedforward, speed=4.0),
             nn.Dropout(dropout),
             nn.Linear(dim_feedforward, d_model),
         )
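
Reviewer note (not part of the patch): ExpScaleSwish fuses the Swish activation x * sigmoid(x) with a learned exponential scale exp(scale * speed), and the custom ExpScaleSwishFunction recomputes that expression under torch.enable_grad() inside backward(), returning x.grad and scale.grad from the recomputed graph. Below is a minimal, forward-only usage sketch showing how the fused module drops into a feed-forward block like the one changed in conformer.py. It assumes subsampling.py is importable from the working directory; the d_model, dim_feedforward, and dropout values are illustrative, not the recipe's defaults.

# Minimal sketch (assumptions noted above): ExpScaleSwish as a drop-in
# replacement for the Swish() + ExpScale(...) pair in a feed-forward block.
import torch
import torch.nn as nn

from subsampling import ExpScaleSwish  # assumes subsampling.py is on the path

d_model, dim_feedforward, dropout = 256, 1024, 0.1  # illustrative sizes

feed_forward = nn.Sequential(
    nn.Linear(d_model, dim_feedforward),
    ExpScaleSwish(dim_feedforward, speed=4.0),  # fused Swish + learned exp-scale
    nn.Dropout(dropout),
    nn.Linear(dim_feedforward, d_model),
)

x = torch.randn(10, 4, d_model)  # (time, batch, d_model)
y = feed_forward(x)
print(y.shape)  # torch.Size([10, 4, 256])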