From d682ecc2463af5eeb8f5615362b3ad1d3abaaca2 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 30 Nov 2022 18:58:25 +0800
Subject: [PATCH] Introduce alpha for DoubleSwish, set it to -0.05.

---
 .../pruned_transducer_stateless7/scaling.py | 22 ++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index f685bf112..b11b2311a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -1064,7 +1064,8 @@ class MaxEig(torch.nn.Module):
 
 class DoubleSwishFunction(torch.autograd.Function):
     """
-      double_swish(x) = x * torch.sigmoid(x-1)
+      double_swish(x) = x * (torch.sigmoid(x-1) + alpha)
+    for e.g. alpha=-0.05 (user supplied).
     This is a definition, originally motivated by its close numerical
     similarity to swish(swish(x)), where swish(x) =  x * sigmoid(x).
 
@@ -1079,9 +1080,10 @@ class DoubleSwishFunction(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, x: Tensor) -> Tensor:
+    def forward(ctx, x: Tensor, alpha: float) -> Tensor:
         requires_grad = x.requires_grad
         x_dtype = x.dtype
+        ctx.alpha = alpha
         if x.dtype == torch.float16:
             x = x.to(torch.float32)
 
@@ -1105,6 +1107,7 @@ class DoubleSwishFunction(torch.autograd.Function):
             assert d_scaled.max() < 256.0
             d_int = d_scaled.to(torch.uint8)
             ctx.save_for_backward(d_int)
+        y = y + alpha * x
         if x.dtype == torch.float16 or torch.is_autocast_enabled():
             y = y.to(torch.float16)
         return y
@@ -1112,20 +1115,29 @@
     @staticmethod
     def backward(ctx, y_grad: Tensor) -> Tensor:
         d, = ctx.saved_tensors
+        alpha = ctx.alpha
         # the same constants as used in forward pass.
         floor = -0.043637
         ceil = 1.2
         d = (d * ((ceil - floor) / 255.0) + floor)
-        return (y_grad * d)
+        return (y_grad * (d + alpha)), None
 
 class DoubleSwish(torch.nn.Module):
+    def __init__(self,
+                 alpha: float = -0.05):
+        super().__init__()
+        self.alpha = alpha
+
+    def extra_repr(self) -> str:
+        return 'alpha={}'.format(self.alpha)
+
     def forward(self, x: Tensor) -> Tensor:
         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
         that we approximate closely with x * sigmoid(x-1).
         """
         if torch.jit.is_scripting():
-            return x * torch.sigmoid(x - 1.0)
-        return DoubleSwishFunction.apply(x)
+            return x * (torch.sigmoid(x - 1.0) + self.alpha)
+        return DoubleSwishFunction.apply(x, self.alpha)
 
 
 class TanSwishFunction(torch.autograd.Function):
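
For context, here is a minimal self-contained sketch (not part of the patch) of what the change does numerically: the new activation is y = x * (sigmoid(x - 1) + alpha), which the patched forward realises as the old double-swish output plus alpha * x, so its derivative is the old derivative plus alpha, matching the "(d + alpha)" in the patched backward. It uses only plain PyTorch and does not import the icefall module.

# Sketch, assuming plain PyTorch; verifies that the two ways of writing
# the patched activation give the same value and gradient.
import torch

alpha = -0.05  # default used by the patched DoubleSwish module

x = torch.randn(1000, dtype=torch.float64, requires_grad=True)

# Reference definition from the patched docstring: x * (sigmoid(x-1) + alpha).
y = x * (torch.sigmoid(x - 1.0) + alpha)
(grad,) = torch.autograd.grad(y.sum(), x)

# Same thing expressed as "old double-swish plus alpha * x",
# which is how the patched forward() computes it.
x2 = x.detach().clone().requires_grad_(True)
y2 = x2 * torch.sigmoid(x2 - 1.0) + alpha * x2
(grad2,) = torch.autograd.grad(y2.sum(), x2)

assert torch.allclose(y, y2)
assert torch.allclose(grad, grad2)  # gradient shifts by exactly alpha

Note that the custom backward in scaling.py stores the old derivative quantized to uint8 (with the floor/ceil constants), so in practice the gradient is this shifted derivative up to that quantization error.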