Merge branch 'scaled_adam_exp188' into scaled_adam_exp198b

Daniel Povey 2022-10-28 12:49:36 +08:00
commit 7b8a0108ea


@@ -201,7 +201,6 @@ def random_cast_to_half(x: Tensor,
     """
     if x.dtype == torch.float16:
         return x
-    x_sign = x.sign()
     x_abs = x.abs()
     is_too_small = (x_abs < min_abs)
     # for elements where is_too_small is true, random_val will contain +-min_abs with
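For context, a minimal self-contained sketch of the function this hunk edits, reconstructed from the docstring and the comment above. The random_val expression and the final torch.where are assumptions inferred from the truncated comment, not the verbatim source:

import torch
from torch import Tensor

def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
    """Cast x to float16; elements with |x| < min_abs are rounded to
    +-min_abs or 0.0 at random, so the cast is unbiased in expectation
    instead of flushing tiny values to zero."""
    if x.dtype == torch.float16:
        return x
    x_abs = x.abs()
    is_too_small = (x_abs < min_abs)
    # where is_too_small: +-min_abs with probability x_abs / min_abs,
    # 0.0 otherwise, which preserves the expected value of each element
    # (assumed expression; the hunk cuts off before it).
    random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
    return torch.where(is_too_small, random_val, x).to(torch.float16)

This also shows why the removed x_sign assignment was dead weight: the sign can be taken inline where random_val is built.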
@@ -223,7 +222,6 @@ class RandomGradFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
-        min_abs = ctx.min_abs
         if ans_grad.dtype == torch.float16:
             return random_cast_to_half(ans_grad.to(torch.float32),
                                        min_abs=ctx.min_abs), None
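The second hunk shows only the float16 branch of backward (the removed local min_abs was redundant, since ctx.min_abs is read directly). A hedged sketch of how the whole autograd Function plausibly fits together; the forward method and the non-float16 fallback lie outside the hunk and are assumptions:

from typing import Tuple
import torch
from torch import Tensor

class RandomGradFunction(torch.autograd.Function):
    """Identity in forward; in backward, float16 gradients are re-cast via
    random_cast_to_half so magnitudes below min_abs survive in expectation."""

    @staticmethod
    def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
        ctx.min_abs = min_abs  # stashed for backward (assumed)
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
        if ans_grad.dtype == torch.float16:
            # randomize tiny fp16 gradient values instead of letting the
            # cast flush them to zero
            return random_cast_to_half(ans_grad.to(torch.float32),
                                       min_abs=ctx.min_abs), None
        else:
            return ans_grad, None

# usage: y behaves like x in forward, but fp16 gradients flowing back
# through y are stochastically re-quantized
# y = RandomGradFunction.apply(x, 5.0e-06)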