Merge branch 'scaled_adam_exp198b' into scaled_adam_exp202

Daniel Povey 2022-10-28 13:13:55 +08:00
commit e592a920b4


@@ -201,7 +201,6 @@ def random_cast_to_half(x: Tensor,
     """
     if x.dtype == torch.float16:
         return x
-    x_sign = x.sign()
     x_abs = x.abs()
     is_too_small = (x_abs < min_abs)
     # for elements where is_too_small is true, random_val will contain +-min_abs with
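
The comment truncated at the end of this hunk describes randomized rounding of small-magnitude values to +-min_abs when casting down to float16. A minimal sketch of that idea in plain PyTorch, assuming a uniform draw via torch.rand_like; the helper name and the default min_abs below are illustrative, not taken from this commit:

import torch
from torch import Tensor

def random_cast_to_half_sketch(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
    # Illustrative reconstruction; not necessarily the exact code in this file.
    if x.dtype == torch.float16:
        return x
    x_abs = x.abs()
    is_too_small = (x_abs < min_abs)
    # Where is_too_small is true, emit +-min_abs with probability (x_abs / min_abs)
    # and 0.0 otherwise, so each element keeps its expectation instead of
    # underflowing to zero when cast to float16.
    random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
    return torch.where(is_too_small, random_val, x).to(torch.float16)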
@@ -223,7 +222,6 @@ class RandomGradFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
-        min_abs = ctx.min_abs
         if ans_grad.dtype == torch.float16:
             return random_cast_to_half(ans_grad.to(torch.float32),
                                        min_abs=ctx.min_abs), None
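
The backward above reads ctx.min_abs, which implies a forward that stashes min_abs on the context and returns its input unchanged. A sketch under that assumption, reusing the random_cast_to_half_sketch helper from the sketch above; the forward shown here is inferred, not part of this diff:

import torch
from torch import Tensor
from typing import Tuple

class RandomGradFunctionSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
        # Inferred: remember min_abs for the backward pass, pass x through.
        ctx.min_abs = min_abs
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
        # Only float16 grads need the randomized cast; others pass through.
        # (The trailing None is the grad for the non-tensor min_abs argument.)
        if ans_grad.dtype == torch.float16:
            return random_cast_to_half_sketch(ans_grad.to(torch.float32),
                                              min_abs=ctx.min_abs), None
        else:
            return ans_grad, None

A call such as y = RandomGradFunctionSketch.apply(x, 5.0e-06) leaves the forward value untouched and only filters tiny float16 gradients in the backward pass.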
@@ -384,8 +382,7 @@ class BasicNorm(torch.nn.Module):
         # region if it happens to exit it.
         eps = eps.clamp(min=self.eps_min, max=self.eps_max)
         scales = (
-            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
-            + self.eps.exp()
+            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
         ) ** -0.5
         return x * scales
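
The effect of this hunk is that the clamped local eps (kept within [eps_min, eps_max]) now feeds the scale computation, where previously the unclamped self.eps parameter was used. A minimal sketch of the forward path under that reading; the constructor defaults and the log-space storage of eps are assumptions about the rest of the class, which this diff does not show:

import torch
from torch import Tensor, nn

class BasicNormSketch(nn.Module):
    def __init__(self, channel_dim: int = -1, eps: float = 0.25,
                 eps_min: float = -3.0, eps_max: float = 3.0):
        super().__init__()
        self.channel_dim = channel_dim
        # Assumed: eps is learned in log space, so eps.exp() below is the
        # actual epsilon added to the mean square.
        self.eps = nn.Parameter(torch.tensor(eps).log())
        self.eps_min = eps_min
        self.eps_max = eps_max

    def forward(self, x: Tensor) -> Tensor:
        # Clamp log-eps into [eps_min, eps_max] and use the clamped value,
        # which is what the change above does (eps.exp() instead of self.eps.exp()).
        eps = self.eps.clamp(min=self.eps_min, max=self.eps_max)
        scales = (
            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
        ) ** -0.5
        return x * scales

This normalizes x by the root-mean-square over channel_dim plus a floor of exp(eps), with no learnable per-channel weight in the sketch.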