From e79833aad278f09792deceab5962b09ae4f56378 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Mon, 12 May 2025 19:28:48 +0800 Subject: [PATCH] ensure SwooshL/SwooshR output dtype matches input dtype (#1940) --- egs/librispeech/ASR/zipformer/scaling.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 6d6281903..11375385e 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1403,9 +1403,9 @@ class SwooshL(torch.nn.Module): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 if not x.requires_grad: - return k2.swoosh_l_forward(x) + return k2.swoosh_l_forward(x).to(x.dtype) else: - return k2.swoosh_l(x) + return k2.swoosh_l(x).to(x.dtype) # return SwooshLFunction.apply(x) @@ -1477,9 +1477,9 @@ class SwooshR(torch.nn.Module): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 if not x.requires_grad: - return k2.swoosh_r_forward(x) + return k2.swoosh_r_forward(x).to(x.dtype) else: - return k2.swoosh_r(x) + return k2.swoosh_r(x).to(x.dtype) # return SwooshRFunction.apply(x)