From 983a690c631fed085e3686bd9d093acbfb2b439c Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 1 Dec 2022 17:20:56 +0800
Subject: [PATCH] Change DoubleSwish formulation, add alpha*x only for
 x.abs() > 0.15.

---
 .../pruned_transducer_stateless7/scaling.py   | 43 +++++++++++--------
 1 file changed, 26 insertions(+), 17 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index b11b2311a..11db31dfd 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -1065,6 +1065,7 @@ class MaxEig(torch.nn.Module):
 class DoubleSwishFunction(torch.autograd.Function):
     """
       double_swish(x) = x * (torch.sigmoid(x-1) + alpha)
+    for alpha=-0.05 (fixed), plus beta * x.clamp(min=-x_limit, max=x_limit) with beta=0.05, x_limit=0.15.
     This is a definition, originally motivated by its close numerical
     similarity to swish(swish(x)), where swish(x) =  x * sigmoid(x).

@@ -1080,26 +1081,36 @@ class DoubleSwishFunction(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, x: Tensor, alpha: float) -> Tensor:
+    def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad
         x_dtype = x.dtype
-        ctx.alpha = alpha
         if x.dtype == torch.float16:
             x = x.to(torch.float32)

         s = torch.sigmoid(x - 1.0)
         y = x * s
+
+        alpha = -0.05
+        beta = 0.05
+        x_limit = 0.15
+
+        # another part of this formula is:
+        #    ... + beta * x.clamp(min=-x_limit, max=x_limit)
+        # the deriv of this is
+        #    beta * (x.abs() < x_limit).
+
         if requires_grad:
-            deriv = (y * (1 - s) + s)
+            deriv = (y * (1 - s) + s)  # ignores the alpha part.
+            deriv = deriv + (x.abs() < x_limit) * beta
             # notes on derivative of x * sigmoid(x - 1):
             # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
-            # min \simeq -0.043638.  Take floor as -0.043637 so it's a lower bund
+            # min \simeq -0.043638.  Take floor as -0.044 so it's a lower bound.
             # max \simeq 1.1990.   Take ceil to be 1.2 so it's an upper bound.
             # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
             # floors), should be expectation-preserving.
-            floor = -0.043637
-            ceil = 1.2
+            floor = -0.044
+            ceil = 1.2 + beta
             d_scaled = ((deriv - floor) * (255.0 / (ceil - floor))
                         + torch.rand_like(deriv))
             if __name__ == "__main__":
                 # for self-testing only.
@@ -1107,7 +1118,7 @@ class DoubleSwishFunction(torch.autograd.Function):
                 assert d_scaled.max() < 256.0
             d_int = d_scaled.to(torch.uint8)
             ctx.save_for_backward(d_int)
-        y = y + alpha * x
+        y = y + alpha * x + beta * x.clamp(min=-x_limit, max=x_limit)
         if x.dtype == torch.float16 or torch.is_autocast_enabled():
             y = y.to(torch.float16)
         return y
@@ -1115,29 +1126,27 @@
     @staticmethod
     def backward(ctx, y_grad: Tensor) -> Tensor:
         d, = ctx.saved_tensors
-        alpha = ctx.alpha
         # the same constants as used in forward pass.
-        floor = -0.043637
-        ceil = 1.2
+        alpha = -0.05
+        beta = 0.05
+        floor = -0.044
+        ceil = 1.2 + beta
+
         d = (d * ((ceil - floor) / 255.0) + floor)
-        return (y_grad * (d + alpha)), None
+        return (y_grad * (d + alpha))


 class DoubleSwish(torch.nn.Module):
-    def __init__(self,
-                 alpha: float = -0.05):
+    def __init__(self):
         super().__init__()
-        self.alpha = alpha

-    def extra_repr(self) -> str:
-        return 'alpha={}'.format(self.alpha)

     def forward(self, x: Tensor) -> Tensor:
         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
         that we approximate closely with x * sigmoid(x-1).
""" if torch.jit.is_scripting(): - return x * (torch.sigmoid(x - 1.0) + self.alpha) - return DoubleSwishFunction.apply(x, self.alpha) + return x * (torch.sigmoid(x - 1.0) - 0.05) + 0.05 * x.clamp(min=-0.15, max=0.15) + return DoubleSwishFunction.apply(x) class TanSwishFunction(torch.autograd.Function):