Set scaling on SwishExpScale

This commit is contained in:
Daniel Povey 2022-03-11 20:12:45 +08:00
parent cc558faf26
commit 2d3a76292d

View File

@ -255,7 +255,9 @@ class SwishExpScale(torch.nn.Module):
def __init__(self, *shape, speed: float = 1.0, in_scale: float = 1.0):
super(SwishExpScale, self).__init__()
self.in_scale = in_scale
self.scale = nn.Parameter(torch.zeros(*shape))
initial_log_scale = torch.tensor(1.0 / in_scale).log() / speed
initial_log_scale = (torch.ones(*shape) * initial_log_scale).detach()
self.scale = nn.Parameter(initial_log_scale)
self.speed = speed
def forward(self, x: Tensor) -> Tensor: