From 2d3a76292d0649a358f39835cd5944c0ac406b37 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 20:12:45 +0800 Subject: [PATCH] Set scaling on SwishExpScale --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 52a58d104..caac230ed 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -255,7 +255,9 @@ class SwishExpScale(torch.nn.Module): def __init__(self, *shape, speed: float = 1.0, in_scale: float = 1.0): super(SwishExpScale, self).__init__() self.in_scale = in_scale - self.scale = nn.Parameter(torch.zeros(*shape)) + initial_log_scale = torch.tensor(1.0 / in_scale).log() / speed + initial_log_scale = (torch.ones(*shape) * initial_log_scale).detach() + self.scale = nn.Parameter(initial_log_scale) self.speed = speed def forward(self, x: Tensor) -> Tensor: