diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py
index 3a1eda3f1..1e31c0a20 100644
--- a/egs/librispeech/ASR/conformer_ctc/subsampling.py
+++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py
@@ -267,23 +267,6 @@ def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor:
     return (x * (scale * speed).exp()).relu()
 
 
-class ExpScaleReluFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor:
-        ctx.save_for_backward(x.detach(), scale.detach())
-        ctx.speed = speed
-        return _exp_scale_swish(x, scale, speed)
-
-    @staticmethod
-    def backward(ctx, y_grad: Tensor) -> Tensor:
-        x, scale = ctx.saved_tensors
-        x.requires_grad = True
-        scale.requires_grad = True
-        with torch.enable_grad():
-            y = _exp_scale_swish(x, scale, ctx.speed)
-            y.backward(gradient=y_grad)
-            return x.grad, scale.grad, None
-
 
 
 class ExpScaleReluFunction(torch.autograd.Function):
@@ -563,16 +546,32 @@ class DerivBalancer(torch.nn.Module):
                                          self.max_abs)
 
 
+def _double_swish(x: Tensor) -> Tensor:
+    # double-swish, implemented/approximated as offset-swish
+    return x * torch.sigmoid(x - 1.0)
+
+class DoubleSwishFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor) -> Tensor:
+        ctx.save_for_backward(x.detach())
+        return _double_swish(x)
+
+    @staticmethod
+    def backward(ctx, y_grad: Tensor) -> Tensor:
+        # TODO: can make this more efficient.
+        x, = ctx.saved_tensors
+        x.requires_grad = True
+        with torch.enable_grad():
+            y = _double_swish(x)
+            y.backward(gradient=y_grad)
+            return x.grad
+
 class DoubleSwish(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
-        that we approximate closely with x * sigmoid(x-1), expressed for more memory-efficient
-        backprop as (x-1) * torch.sigmoid(x - 1) + torch.sigmoid(x - 1)
+        that we approximate closely with x * sigmoid(x-1).
        """
-        x1 = x - 1.0
-        s = torch.sigmoid(x1)
-        return (x1 * s) + s  # (x-1) * s + s == x * s
-
+        return DoubleSwishFunction.apply(x)
 
 
 def _test_exp_scale_swish():
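Note (not part of the patch): the new `DoubleSwishFunction` trades compute for memory. Its `forward()` saves only the detached input and recomputes the activation inside `backward()` under `torch.enable_grad()`, instead of keeping the sigmoid output alive for backprop. Below is a minimal standalone sketch of the same pattern, assuming only `torch`; it checks the custom backward against plain autograd with `torch.autograd.gradcheck`:

```python
# Standalone sketch reproducing the recompute-in-backward pattern from the
# patch, so it can be sanity-checked outside icefall.
import torch
from torch import Tensor


def _double_swish(x: Tensor) -> Tensor:
    # double-swish, approximated as x * sigmoid(x - 1)
    return x * torch.sigmoid(x - 1.0)


class DoubleSwishFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        # Save only the detached input; no intermediate activations are kept.
        ctx.save_for_backward(x.detach())
        return _double_swish(x)

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        # Recompute the forward with grad enabled and let autograd
        # produce the input gradient for us.
        x, = ctx.saved_tensors
        x.requires_grad = True
        with torch.enable_grad():
            y = _double_swish(x)
            y.backward(gradient=y_grad)
        return x.grad


if __name__ == "__main__":
    # gradcheck compares the analytic gradient against finite differences;
    # it needs float64 inputs to be numerically meaningful.
    x = torch.randn(20, dtype=torch.float64, requires_grad=True)
    assert torch.autograd.gradcheck(DoubleSwishFunction.apply, (x,))
    # Forward values agree with the plain expression.
    assert torch.allclose(DoubleSwishFunction.apply(x), _double_swish(x))
    print("DoubleSwishFunction matches plain autograd")
```

Passing `gradcheck` here confirms that recomputing the forward in `backward()` reproduces the exact analytic gradient of `x * sigmoid(x - 1)`; the cost is one extra forward evaluation per backward pass, in exchange for not storing the activation between the passes.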