Fix bug: pass zero, not one, as the first argument to torch.logaddexp (log(1 + e^x) = logaddexp(0, x), since e^0 = 1)

This commit is contained in:
Daniel Povey 2022-12-02 19:12:18 +08:00
parent 2bfc38207c
commit 9a2a58e20d

View File

@ -1229,13 +1229,13 @@ class SwooshFunction(torch.autograd.Function):
if x.dtype == torch.float16:
x = x.to(torch.float32)
one = torch.tensor(1.0, dtype=x.dtype, device=x.device)
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
with torch.cuda.amp.autocast(enabled=False):
with torch.enable_grad():
x = x.detach()
x.requires_grad = True
y = torch.logaddexp(one, x - 1.125) - 0.08 * x - 0.3
y = torch.logaddexp(zero, x - 1.125) - 0.08 * x - 0.3
if not requires_grad:
return y