From ba3611cefd1af82ef343beec9daef9d2e795f3a0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 18 Mar 2022 16:35:48 +0800 Subject: [PATCH] Cosmetic changes to swish --- .../pruned_transducer_stateless2/scaling.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index f0e3fe148..d03bd0967 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -310,22 +310,22 @@ class ActivationBalancer(torch.nn.Module): self.max_factor, self.min_abs, self.max_abs) -# deriv of double_swish: -# double_swish(x) = x * torch.sigmoid(x-1) [this is a definition, originally -# motivated by its similarity to swish(swish(x), -# where swish(x) = x *sigmoid(x)]. -# double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) -# double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). -# Now, s'(x) = s(x) * (1-s(x)). -# double_swish'(x) = x * s'(x) + s(x). -# = x * s(x) * (1-s(x)) + s(x). -# = double_swish(x) * (1-s(x)) + s(x) - -def _double_swish(x: Tensor) -> Tensor: - # double-swish, implemented/approximated as offset-swish - return x * torch.sigmoid(x - 1.0) class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). + + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). + = double_swish(x) * (1-s(x)) + s(x) + ... so we just need to remember s(x) but not x itself. 
+ """ @staticmethod def forward(ctx, x: Tensor) -> Tensor: x = x.detach()