mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Set scaling on SwishExpScale
This commit is contained in:
parent
cc558faf26
commit
2d3a76292d
@ -255,7 +255,9 @@ class SwishExpScale(torch.nn.Module):
|
||||
def __init__(self, *shape, speed: float = 1.0, in_scale: float = 1.0):
|
||||
super(SwishExpScale, self).__init__()
|
||||
self.in_scale = in_scale
|
||||
self.scale = nn.Parameter(torch.zeros(*shape))
|
||||
initial_log_scale = torch.tensor(1.0 / in_scale).log() / speed
|
||||
initial_log_scale = (torch.ones(*shape) * initial_log_scale).detach()
|
||||
self.scale = nn.Parameter(initial_log_scale)
|
||||
self.speed = speed
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
|
Loading…
x
Reference in New Issue
Block a user