Add RandomGrad with min_abs=1.0e-04

Repository: https://github.com/k2-fsa/icefall.git
commit: a4443efa95
parent: 0ad4462632
@@ -36,6 +36,7 @@ from scaling import (
     _diag,
     random_clamp,
     softmax,
+    RandomGrad,
 )
 from torch import Tensor, nn
 
@@ -304,7 +305,7 @@ class ConformerEncoderLayer(nn.Module):
             whitening_limit=5.0,
             prob=(0.025, 0.25),
             grad_scale=0.01)
+        self.random_grad = RandomGrad()
 
     def forward(
         self,
@@ -364,7 +365,7 @@ class ConformerEncoderLayer(nn.Module):
         bypass_scale = bypass_scale.clamp(min=0.1, max=1.0)
         src = src_orig + delta * self.bypass_scale
 
-        return self.whiten(src)
+        return self.random_grad(self.whiten(src))
 
 
 class ConformerEncoder(nn.Module):
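For orientation (not part of the commit): RandomGrad is an identity in the forward pass and is bypassed entirely under torch.jit.is_scripting() or in eval mode, so wrapping the whitened output only changes how gradients are treated on the way back. The sketch below shows the same wiring pattern on a hypothetical TinyBlock stand-in for ConformerEncoderLayer; it assumes icefall's scaling module (with the RandomGrad added in this commit) is importable.

```python
# Minimal sketch of the wiring pattern used above; TinyBlock is a hypothetical
# stand-in for ConformerEncoderLayer, and `scaling` is assumed to be on the path.
import torch
from torch import Tensor, nn

from scaling import RandomGrad  # added by this commit


class TinyBlock(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.random_grad = RandomGrad()  # default min_abs=1.0e-04

    def forward(self, x: Tensor) -> Tensor:
        # Same pattern as the ConformerEncoderLayer change: wrap the block output.
        return self.random_grad(self.linear(x))


if __name__ == "__main__":
    rg = RandomGrad()
    x = torch.randn(2, 16)
    # Forward pass is a no-op; only float16 gradients are modified in backward.
    assert torch.equal(rg(x), x)
```

Because the module does nothing in forward and is skipped when scripting or in eval mode, inference behaviour is unaffected; the only effect is on float16 gradients during amp training.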
@@ -211,6 +211,42 @@ def random_cast_to_half(x: Tensor,
     random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
     return torch.where(is_too_small, random_val, x).to(torch.float16)
 
+class RandomGradFunction(torch.autograd.Function):
+    """
+    Does nothing in forward pass; in backward pass, gets rid of very small grads using
+    randomized approach that preserves expectations (intended to reduce roundoff).
+    """
+    @staticmethod
+    def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
+        ctx.min_abs = min_abs
+        return x
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
+        min_abs = ctx.min_abs
+        if ans_grad.dtype == torch.float16:
+            return random_cast_to_half(ans_grad.to(torch.float32),
+                                       min_abs=ctx.min_abs), None
+        else:
+            return ans_grad, None
+
+
+class RandomGrad(torch.nn.Module):
+    """
+    Gets rid of very small gradients using an expectation-preserving method, intended to increase
+    accuracy of training when using amp (automatic mixed precision)
+    """
+    def __init__(self,
+                 min_abs: float = 1.0e-04):
+        super(RandomGrad, self).__init__()
+        self.min_abs = min_abs
+
+    def forward(self,
+                x: Tensor):
+        if torch.jit.is_scripting() or not self.training:
+            return x
+        else:
+            return RandomGradFunction.apply(x, self.min_abs)
 
 class SoftmaxFunction(torch.autograd.Function):
     """
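The "expectation-preserving" behaviour comes from random_cast_to_half, a pre-existing helper whose tail is visible as context at the top of this hunk: a value with |x| < min_abs is rounded to sign(x) * min_abs with probability |x| / min_abs and to zero otherwise, so its expected value stays x rather than accumulating the roundoff that very small float16 values suffer. A rough sanity check, assuming random_cast_to_half is importable from icefall's scaling module:

```python
# Sanity check (not part of the commit): the randomized rounding preserves the
# mean of values far below min_abs. Assumes icefall's scaling.py is importable.
import torch

from scaling import random_cast_to_half  # pre-existing helper used by RandomGradFunction

min_abs = 1.0e-04
x = torch.full((1_000_000,), 2.5e-05)        # |x| < min_abs

y = random_cast_to_half(x, min_abs=min_abs)  # float16, each value is 0 or ~min_abs
print(y.dtype)            # torch.float16
print(y.float().mean())   # ~2.5e-05: roughly a quarter of the entries became min_abs,
                          # the rest zero, so the mean matches x in expectation
```

In RandomGradFunction.backward this path is taken only when the incoming gradient is float16 (in practice, under amp); float32 gradients are returned unchanged, and RandomGrad itself is bypassed in eval mode or when scripting.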