mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
DoubleSwish fix
This commit is contained in:
parent
be0a79cbca
commit
2117f46361
@ -537,13 +537,13 @@ class DerivBalancer(torch.nn.Module):
|
|||||||
self.max_factor, self.min_abs)
|
self.max_factor, self.min_abs)
|
||||||
|
|
||||||
|
|
||||||
|
class DoubleSwish(torch.nn.Module):
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
"""Return Swich activation function."""
|
||||||
|
return x * torch.sigmoid(x - 1.0)
|
||||||
|
|
||||||
|
|
||||||
def _test_exp_scale_swish():
|
def _test_exp_scale_swish():
|
||||||
class DoubleSwish(torch.nn.Module):
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
"""Return Swich activation function."""
|
|
||||||
return x * torch.sigmoid(x - 1.0)
|
|
||||||
|
|
||||||
x1 = torch.randn(50, 60).detach()
|
x1 = torch.randn(50, 60).detach()
|
||||||
x2 = x1.detach()
|
x2 = x1.detach()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user