diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 81f025b79..936a77b8c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -1235,7 +1235,7 @@ class SwooshFunction(torch.autograd.Function): with torch.enable_grad(): x = x.detach() x.requires_grad = True - y = torch.logaddexp(one, x - 4) - 0.08 * x - 0.15 + y = torch.logaddexp(one, x - 1.125) - 0.08 * x - 0.3 if not requires_grad: return y @@ -1273,7 +1273,7 @@ class Swoosh(torch.nn.Module): """ if torch.jit.is_scripting(): one = torch.tensor(1.0, dtype=x.dtype, device=x.device) - return torch.logaddexp(one, x - 4) - 0.08 * x - 0.15 + return torch.logaddexp(one, x - 1.125) - 0.08 * x - 0.3 return SwooshFunction.apply(x)