Add RandomGrad with min_abs=1.0e-04

Daniel Povey 2022-10-19 19:46:17 +08:00
parent 0ad4462632
commit a4443efa95
2 changed files with 39 additions and 2 deletions


@@ -36,6 +36,7 @@ from scaling import (
     _diag,
     random_clamp,
     softmax,
+    RandomGrad,
 )
 from torch import Tensor, nn
@@ -304,7 +305,7 @@ class ConformerEncoderLayer(nn.Module):
             whitening_limit=5.0,
             prob=(0.025, 0.25),
             grad_scale=0.01)
+        self.random_grad = RandomGrad()
 
     def forward(
         self,
@@ -364,7 +365,7 @@ class ConformerEncoderLayer(nn.Module):
             bypass_scale = bypass_scale.clamp(min=0.1, max=1.0)
         src = src_orig + delta * self.bypass_scale
-        return self.whiten(src)
+        return self.random_grad(self.whiten(src))
 
 class ConformerEncoder(nn.Module):
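In the first file, the encoder layer constructs a RandomGrad module in its __init__ and wraps it around the whitened output, so the forward computation is unchanged and only the backward pass is affected. A minimal sketch of the same wiring, assuming the scaling.py from this commit is importable (TinyLayer and its sizes are illustrative stand-ins, not code from the repository):

import torch
from torch import Tensor, nn
from scaling import RandomGrad  # added to scaling.py by the second file below


class TinyLayer(nn.Module):
    # Stand-in layer: the point is only where RandomGrad sits, i.e.
    # wrapped around the layer's final output.
    def __init__(self, d_model: int = 16):
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)
        self.random_grad = RandomGrad()  # min_abs defaults to 1.0e-04

    def forward(self, src: Tensor) -> Tensor:
        # RandomGrad is the identity in the forward direction; in the
        # backward direction, float16 gradients with magnitude below
        # min_abs are randomly rounded to 0 or +/- min_abs in a way that
        # preserves their expected value.
        return self.random_grad(self.linear(src))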


@@ -211,6 +211,42 @@ def random_cast_to_half(x: Tensor,
     random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
     return torch.where(is_too_small, random_val, x).to(torch.float16)
 
+class RandomGradFunction(torch.autograd.Function):
+    """
+    Does nothing in forward pass; in backward pass, gets rid of very small grads using
+    randomized approach that preserves expectations (intended to reduce roundoff).
+    """
+    @staticmethod
+    def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
+        ctx.min_abs = min_abs
+        return x
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
+        min_abs = ctx.min_abs
+        if ans_grad.dtype == torch.float16:
+            return random_cast_to_half(ans_grad.to(torch.float32),
+                                       min_abs=ctx.min_abs), None
+        else:
+            return ans_grad, None
+
+class RandomGrad(torch.nn.Module):
+    """
+    Gets rid of very small gradients using an expectation-preserving method, intended to increase
+    accuracy of training when using amp (automatic mixed precision)
+    """
+    def __init__(self,
+                 min_abs: float = 1.0e-04):
+        super(RandomGrad, self).__init__()
+        self.min_abs = min_abs
+
+    def forward(self,
+                x: Tensor):
+        if torch.jit.is_scripting() or not self.training:
+            return x
+        else:
+            return RandomGradFunction.apply(x, self.min_abs)
+
 class SoftmaxFunction(torch.autograd.Function):
     """