DoubleSwish fix

This commit is contained in:
Daniel Povey 2022-03-12 19:02:14 +08:00
parent be0a79cbca
commit 2117f46361

View File

@ -537,14 +537,14 @@ class DerivBalancer(torch.nn.Module):
self.max_factor, self.min_abs)
def _test_exp_scale_swish():
class DoubleSwish(torch.nn.Module):
class DoubleSwish(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Return Swich activation function."""
return x * torch.sigmoid(x - 1.0)
def _test_exp_scale_swish():
x1 = torch.randn(50, 60).detach()
x2 = x1.detach()