Cast to float16 in DoubleSwish forward

This commit is contained in:
Daniel Povey 2022-10-25 13:16:00 +08:00
parent 3159b09e8f
commit dbfbd8016b

View File

@ -939,6 +939,8 @@ class DoubleSwishFunction(torch.autograd.Function):
assert d_scaled.max() < 256.0
d_int = d_scaled.to(torch.uint8)
ctx.save_for_backward(d_int)
if x.dtype == torch.float16 or torch.is_autocast_enabled():
y = y.to(torch.float16)
return y
@staticmethod