diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 565990708..52aa66bc3 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -36,6 +36,7 @@ from scaling import (
     _diag,
     random_clamp,
     softmax,
+    RandomGrad,
 )
 from torch import Tensor, nn
 
@@ -304,7 +305,7 @@ class ConformerEncoderLayer(nn.Module):
                              whitening_limit=5.0,
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
-
+        self.random_grad = RandomGrad()
 
     def forward(
         self,
@@ -364,7 +365,7 @@ class ConformerEncoderLayer(nn.Module):
            bypass_scale = bypass_scale.clamp(min=0.1, max=1.0)
        src = src_orig + delta * self.bypass_scale
 
-        return self.whiten(src)
+        return self.random_grad(self.whiten(src))
 
 
 class ConformerEncoder(nn.Module):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 773bab4e9..23afb4387 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -211,6 +211,42 @@ def random_cast_to_half(x: Tensor,
     random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
     return torch.where(is_too_small, random_val, x).to(torch.float16)
 
+class RandomGradFunction(torch.autograd.Function):
+    """
+    Does nothing in the forward pass; in the backward pass, gets rid of very small grads
+    using a randomized approach that preserves expectations (intended to reduce roundoff).
+    """
+    @staticmethod
+    def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
+        ctx.min_abs = min_abs
+        return x
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
+        min_abs = ctx.min_abs
+        if ans_grad.dtype == torch.float16:
+            return random_cast_to_half(ans_grad.to(torch.float32),
+                                       min_abs=min_abs), None
+        else:
+            return ans_grad, None
+
+class RandomGrad(torch.nn.Module):
+    """
+    Gets rid of very small gradients using an expectation-preserving method, intended to
+    increase the accuracy of training when using amp (automatic mixed precision).
+    """
+    def __init__(self,
+                 min_abs: float = 1.0e-04):
+        super(RandomGrad, self).__init__()
+        self.min_abs = min_abs
+
+    def forward(self,
+                x: Tensor) -> Tensor:
+        if torch.jit.is_scripting() or not self.training:
+            return x
+        else:
+            return RandomGradFunction.apply(x, self.min_abs)
+
 
 class SoftmaxFunction(torch.autograd.Function):
     """
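
A quick way to see why this helps under amp: the randomized cast that RandomGradFunction applies in its backward pass replaces each gradient element with |g| < min_abs by +/-min_abs with probability |g| / min_abs, and by zero otherwise, so the expected value of every element is unchanged while gradients too small to survive a float16 cast are no longer silently lost to roundoff. The sketch below is not part of the patch; it mirrors the logic of random_cast_to_half from scaling.py in a standalone helper (the name random_cast_to_half_sketch, the min_abs value, the seed, and the tensor size are illustrative choices).

import torch
from torch import Tensor


def random_cast_to_half_sketch(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
    # Elements with |x| >= min_abs are cast normally; smaller elements become
    # +/-min_abs with probability |x| / min_abs and 0 otherwise, so the
    # expectation of each element is preserved.
    x_abs = x.abs()
    is_too_small = x_abs < min_abs
    random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
    return torch.where(is_too_small, random_val, x).to(torch.float16)


torch.manual_seed(0)
# 2e-08 is below the smallest nonzero float16 value (about 6e-08), so a plain
# cast flushes every element to zero; the randomized cast keeps the mean alive.
grad = torch.full((100_000,), 2.0e-08)
print(grad.to(torch.float16).float().mean())             # 0.0
print(random_cast_to_half_sketch(grad).float().mean())   # roughly 2e-08

Because RandomGrad.forward returns x unchanged and only routes through RandomGradFunction while self.training is true, and the backward pass only rewrites gradients that arrive as float16, wrapping the whitened encoder-layer output in self.random_grad is an identity at inference time and effectively a no-op for full-precision training; it only affects the half-precision gradients produced under amp.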