mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
ensure SwooshL/SwooshR output dtype matches input dtype (#1940)
This commit is contained in:
parent
4627969ccd
commit
e79833aad2
@ -1403,9 +1403,9 @@ class SwooshL(torch.nn.Module):
|
||||
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
||||
return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
|
||||
if not x.requires_grad:
|
||||
return k2.swoosh_l_forward(x)
|
||||
return k2.swoosh_l_forward(x).to(x.dtype)
|
||||
else:
|
||||
return k2.swoosh_l(x)
|
||||
return k2.swoosh_l(x).to(x.dtype)
|
||||
# return SwooshLFunction.apply(x)
|
||||
|
||||
|
||||
@ -1477,9 +1477,9 @@ class SwooshR(torch.nn.Module):
|
||||
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
||||
return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
|
||||
if not x.requires_grad:
|
||||
return k2.swoosh_r_forward(x)
|
||||
return k2.swoosh_r_forward(x).to(x.dtype)
|
||||
else:
|
||||
return k2.swoosh_r(x)
|
||||
return k2.swoosh_r(x).to(x.dtype)
|
||||
# return SwooshRFunction.apply(x)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user