fix SwooshR and SwooshL

Yifan Yang, 2025-05-12 00:48:42 +08:00, committed by GitHub
parent cd3adad46d
commit 5fbeed9f96


@@ -1403,9 +1403,9 @@ class SwooshL(torch.nn.Module):
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
             return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
         if not x.requires_grad:
-            return k2.swoosh_l_forward(x)
+            return k2.swoosh_l_forward(x).to(x.dtype)
         else:
-            return k2.swoosh_l(x)
+            return k2.swoosh_l(x).to(x.dtype)
         # return SwooshLFunction.apply(x)
@@ -1477,9 +1477,9 @@ class SwooshR(torch.nn.Module):
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
             return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
         if not x.requires_grad:
-            return k2.swoosh_r_forward(x)
+            return k2.swoosh_r_forward(x).to(x.dtype)
         else:
-            return k2.swoosh_r(x)
+            return k2.swoosh_r(x).to(x.dtype)
         # return SwooshRFunction.apply(x)
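
Note on the change: the added .to(x.dtype) calls make both k2-backed branches return a tensor in the caller's dtype. A plausible reading (an assumption, not stated in the diff) is that the k2 kernels can hand back a float32 result even for a float16/bfloat16 input, which would silently change dtype in mixed-precision training. The sketch below is a minimal pure-PyTorch illustration of that situation; swoosh_l_reference and swoosh_l_like_k2 are hypothetical names, and the "kernel computes in float32" behavior is assumed only to show why the cast matters.

    # Minimal sketch, not the k2 implementation.
    import torch


    def swoosh_l_reference(x: torch.Tensor) -> torch.Tensor:
        # Same formula as the jit fallback branch above.
        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
        return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035


    def swoosh_l_like_k2(x: torch.Tensor) -> torch.Tensor:
        # Assume the kernel computes in float32, as a CUDA op might;
        # the trailing .to(x.dtype) restores the input dtype, which is
        # exactly what this commit adds.
        y = swoosh_l_reference(x.to(torch.float32))
        return y.to(x.dtype)


    if __name__ == "__main__":
        x = torch.randn(4, dtype=torch.float16)
        # Without the cast, y would come back as float32.
        assert swoosh_l_like_k2(x).dtype == torch.float16

The SwooshR case is identical apart from the constants (x - 1.0 and 0.313261687 instead of x - 4.0 and 0.035), so the same dtype-preserving cast applies there.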