Mirror of https://github.com/k2-fsa/icefall.git
Add RandomGrad with min_abs=1.0e-04
commit a4443efa95 (parent 0ad4462632)
@@ -36,6 +36,7 @@ from scaling import (
     _diag,
     random_clamp,
     softmax,
+    RandomGrad,
 )
 from torch import Tensor, nn
 
@@ -304,7 +305,7 @@ class ConformerEncoderLayer(nn.Module):
                              whitening_limit=5.0,
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
 
+        self.random_grad = RandomGrad()
 
     def forward(
         self,
@@ -364,7 +365,7 @@ class ConformerEncoderLayer(nn.Module):
         bypass_scale = bypass_scale.clamp(min=0.1, max=1.0)
         src = src_orig + delta * self.bypass_scale
 
-        return self.whiten(src)
+        return self.random_grad(self.whiten(src))
 
 
 class ConformerEncoder(nn.Module):
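The only behavioural change in ConformerEncoderLayer.forward() is that the whitened output now passes through self.random_grad. A minimal sketch of what that means, assuming the RandomGrad module added in the scaling.py hunk below is importable (the encoder file imports it via the first hunk's "from scaling import ... RandomGrad"): the wrapper is an identity in the forward pass and is skipped entirely in eval mode or under torch.jit scripting, so only gradients are touched, and only when they arrive as float16.

import torch
from scaling import RandomGrad   # same import that the first hunk adds

m = RandomGrad(min_abs=1.0e-04)  # default value matches the commit title
m.train()

x = torch.randn(4, 8, requires_grad=True)
y = m(x)
assert torch.equal(y, x)         # identity in the forward pass

y.sum().backward()
# With float32 gradients (no amp) the backward pass is also a plain pass-through;
# only float16 gradients get the randomized small-value treatment.
assert torch.equal(x.grad, torch.ones_like(x))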
@@ -211,6 +211,42 @@ def random_cast_to_half(x: Tensor,
     random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
     return torch.where(is_too_small, random_val, x).to(torch.float16)
 
+class RandomGradFunction(torch.autograd.Function):
+    """
+    Does nothing in forward pass; in backward pass, gets rid of very small grads using
+    randomized approach that preserves expectations (intended to reduce roundoff).
+    """
+    @staticmethod
+    def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
+        ctx.min_abs = min_abs
+        return x
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
+        min_abs = ctx.min_abs
+        if ans_grad.dtype == torch.float16:
+            return random_cast_to_half(ans_grad.to(torch.float32),
+                                       min_abs=ctx.min_abs), None
+        else:
+            return ans_grad, None
+
+class RandomGrad(torch.nn.Module):
+    """
+    Gets rid of very small gradients using an expectation-preserving method, intended to increase
+    accuracy of training when using amp (automatic mixed precision)
+    """
+    def __init__(self,
+                 min_abs: float = 1.0e-04):
+        super(RandomGrad, self).__init__()
+        self.min_abs = min_abs
+
+    def forward(self,
+                x: Tensor):
+        if torch.jit.is_scripting() or not self.training:
+            return x
+        else:
+            return RandomGradFunction.apply(x, self.min_abs)
+
 
 class SoftmaxFunction(torch.autograd.Function):
     """
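For the float16 case, RandomGradFunction.backward() reuses the randomized rounding from random_cast_to_half shown at the top of this hunk: a gradient g with |g| < min_abs becomes sign(g) * min_abs with probability |g| / min_abs and zero otherwise, so its expectation is unchanged while the tiny values that suffer roundoff in float16 are removed. A self-contained sketch of that expectation-preserving step (random_round_small is a hypothetical helper that mirrors the expression above, minus the final cast to float16):

import torch

def random_round_small(x: torch.Tensor, min_abs: float = 1.0e-04) -> torch.Tensor:
    # |x| >= min_abs passes through; smaller values become sign(x) * min_abs
    # with probability |x| / min_abs, else 0, so E[output] == x.
    x_abs = x.abs()
    is_too_small = x_abs < min_abs
    random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
    return torch.where(is_too_small, random_val, x)

torch.manual_seed(0)
g = torch.full((1_000_000,), 3.0e-05)   # well below min_abs = 1.0e-04
r = random_round_small(g)
print(r.unique())        # only 0.0 and 1.0e-04 remain
print(r.mean().item())   # ~3.0e-05: the mean gradient is preserved

In the commit itself this only happens in the backward pass, and only when the incoming gradient is float16, which is the amp case the RandomGrad docstring targets.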