Cosmetic changes to swish

Daniel Povey 2022-03-18 16:35:48 +08:00
parent 6769087d70
commit ba3611cefd


@@ -310,22 +310,22 @@ class ActivationBalancer(torch.nn.Module):
             self.max_factor, self.min_abs,
             self.max_abs)
-# deriv of double_swish:
-# double_swish(x) = x * torch.sigmoid(x-1)  [this is a definition, originally
-#                   motivated by its similarity to swish(swish(x)),
-#                   where swish(x) = x * sigmoid(x)].
-# double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
-# double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
-# Now, s'(x) = s(x) * (1-s(x)).
-# double_swish'(x) = x * s'(x) + s(x)
-#                  = x * s(x) * (1-s(x)) + s(x)
-#                  = double_swish(x) * (1-s(x)) + s(x)
-def _double_swish(x: Tensor) -> Tensor:
-    # double-swish, implemented/approximated as offset-swish
-    return x * torch.sigmoid(x - 1.0)
 class DoubleSwishFunction(torch.autograd.Function):
"""
double_swish(x) = x * torch.sigmoid(x-1)
This is a definition, originally motivated by its close numerical
similarity to swish(swish(x), where swish(x) = x * sigmoid(x).
Memory-efficient derivative computation:
double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
Now, s'(x) = s(x) * (1-s(x)).
double_swish'(x) = x * s'(x) + s(x).
= x * s(x) * (1-s(x)) + s(x).
= double_swish(x) * (1-s(x)) + s(x)
... so we just need to remember s(x) but not x itself.
"""
     @staticmethod
     def forward(ctx, x: Tensor) -> Tensor:
         x = x.detach()
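
The diff truncates forward() at the detach, so for reference here is a minimal runnable sketch of the scheme the docstring derives: forward saves only s(x) and y = double_swish(x), and backward applies double_swish'(x) = y * (1 - s) + s. Everything past `x = x.detach()`, and the class name DoubleSwishSketch, is reconstructed from the formula rather than copied from the commit.

```python
import torch
from torch import Tensor


class DoubleSwishSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        x = x.detach()
        s = torch.sigmoid(x - 1.0)   # s(x)
        y = x * s                    # double_swish(x)
        # Save only s(x) and the output y; x itself is never stored.
        ctx.save_for_backward(s, y)
        return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        s, y = ctx.saved_tensors
        # double_swish'(x) = double_swish(x) * (1 - s(x)) + s(x)
        return (y * (1.0 - s) + s) * y_grad


if __name__ == "__main__":
    # Verify the hand-derived backward against numerical differentiation.
    x = torch.randn(20, dtype=torch.double, requires_grad=True)
    assert torch.autograd.gradcheck(DoubleSwishSketch.apply, (x,))
```

The point of rearranging the derivative into y * (1 - s) + s is that backward needs only tensors the forward pass already produced, so no extra activation has to be kept alive for the gradient.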