mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Cosmetic changes to swish
This commit is contained in:
parent
6769087d70
commit
ba3611cefd
@ -310,22 +310,22 @@ class ActivationBalancer(torch.nn.Module):
|
|||||||
self.max_factor, self.min_abs,
|
self.max_factor, self.min_abs,
|
||||||
self.max_abs)
|
self.max_abs)
|
||||||
|
|
||||||
# deriv of double_swish:
|
|
||||||
# double_swish(x) = x * torch.sigmoid(x-1) [this is a definition, originally
|
|
||||||
# motivated by its similarity to swish(swish(x)),
|
|
||||||
# where swish(x) = x * sigmoid(x)].
|
|
||||||
# double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
|
|
||||||
# double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
|
|
||||||
# Now, s'(x) = s(x) * (1-s(x)).
|
|
||||||
# double_swish'(x) = x * s'(x) + s(x).
|
|
||||||
# = x * s(x) * (1-s(x)) + s(x).
|
|
||||||
# = double_swish(x) * (1-s(x)) + s(x)
|
|
||||||
|
|
||||||
def _double_swish(x: Tensor) -> Tensor:
|
|
||||||
# double-swish, implemented/approximated as offset-swish
|
|
||||||
return x * torch.sigmoid(x - 1.0)
|
|
||||||
|
|
||||||
class DoubleSwishFunction(torch.autograd.Function):
|
class DoubleSwishFunction(torch.autograd.Function):
|
||||||
|
"""
|
||||||
|
double_swish(x) = x * torch.sigmoid(x-1)
|
||||||
|
This is a definition, originally motivated by its close numerical
|
||||||
|
similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
|
||||||
|
|
||||||
|
Memory-efficient derivative computation:
|
||||||
|
double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
|
||||||
|
double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
|
||||||
|
Now, s'(x) = s(x) * (1-s(x)).
|
||||||
|
double_swish'(x) = x * s'(x) + s(x).
|
||||||
|
= x * s(x) * (1-s(x)) + s(x).
|
||||||
|
= double_swish(x) * (1-s(x)) + s(x)
|
||||||
|
... so we just need to remember s(x) but not x itself.
|
||||||
|
"""
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, x: Tensor) -> Tensor:
|
def forward(ctx, x: Tensor) -> Tensor:
|
||||||
x = x.detach()
|
x = x.detach()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user