mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Set scaling on SwishExpScale
This commit is contained in:
parent
cc558faf26
commit
2d3a76292d
@ -255,7 +255,9 @@ class SwishExpScale(torch.nn.Module):
|
|||||||
def __init__(self, *shape, speed: float = 1.0, in_scale: float = 1.0):
|
def __init__(self, *shape, speed: float = 1.0, in_scale: float = 1.0):
|
||||||
super(SwishExpScale, self).__init__()
|
super(SwishExpScale, self).__init__()
|
||||||
self.in_scale = in_scale
|
self.in_scale = in_scale
|
||||||
self.scale = nn.Parameter(torch.zeros(*shape))
|
initial_log_scale = torch.tensor(1.0 / in_scale).log() / speed
|
||||||
|
initial_log_scale = (torch.ones(*shape) * initial_log_scale).detach()
|
||||||
|
self.scale = nn.Parameter(initial_log_scale)
|
||||||
self.speed = speed
|
self.speed = speed
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user