ensure SwooshL/SwooshR output dtype matches input dtype (#1940)

This commit is contained in:
Yifan Yang 2025-05-12 19:28:48 +08:00 committed by GitHub
parent 4627969ccd
commit e79833aad2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1403,9 +1403,9 @@ class SwooshL(torch.nn.Module):
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
if not x.requires_grad:
return k2.swoosh_l_forward(x)
return k2.swoosh_l_forward(x).to(x.dtype)
else:
return k2.swoosh_l(x)
return k2.swoosh_l(x).to(x.dtype)
# return SwooshLFunction.apply(x)
@ -1477,9 +1477,9 @@ class SwooshR(torch.nn.Module):
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
if not x.requires_grad:
return k2.swoosh_r_forward(x)
return k2.swoosh_r_forward(x).to(x.dtype)
else:
return k2.swoosh_r(x)
return k2.swoosh_r(x).to(x.dtype)
# return SwooshRFunction.apply(x)