diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 32bfab990..67671c4a5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -1216,7 +1216,7 @@ class TanSwish(torch.nn.Module): class SwooshLFunction(torch.autograd.Function): """ - swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 + swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 derivatives are between -0.08 and 0.92. """ @@ -1235,7 +1235,7 @@ class SwooshLFunction(torch.autograd.Function): with torch.enable_grad(): x = x.detach() x.requires_grad = True - y = torch.logaddexp(zero, x - 4.0) - 0.1 * x - 0.035 + y = torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 if not requires_grad: return y @@ -1273,7 +1273,7 @@ class SwooshL(torch.nn.Module): """ if torch.jit.is_scripting(): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return torch.logaddexp(zero, x - 4.0) - 0.1 * x - 0.035 + return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 return SwooshLFunction.apply(x)