mirror of https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00

commit 610281eaa2
Keep just the RandomGrad changes, vs. 149. Git history may not reflect real changes.
@@ -19,12 +19,10 @@ import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import random_clamp

from icefall.utils import add_sos



class Transducer(nn.Module):
    """It implements https://arxiv.org/pdf/1211.3711.pdf
    "Sequence Transduction with Recurrent Neural Networks"
@@ -142,12 +140,6 @@ class Transducer(nn.Module):
        lm = self.simple_lm_proj(decoder_out)
        am = self.simple_am_proj(encoder_out)

        if self.training:
            lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5,
                              reflect=0.1)
            am = random_clamp(am, min=-5.0, max=5.0, prob=0.5,
                              reflect=0.1)

        with torch.cuda.amp.autocast(enabled=False):
            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                lm=lm.float(),
@@ -183,10 +175,6 @@ class Transducer(nn.Module):
        # prior to do_rnnt_pruning (this is an optimization for speed).
        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

        if self.training:
            logits = random_clamp(logits, -8.0, 2.0, prob=0.5,
                                  reflect=0.1)

        with torch.cuda.amp.autocast(enabled=False):
            pruned_loss = k2.rnnt_loss_pruned(
                logits=logits.float(),

@@ -175,7 +175,6 @@ class RandomClampFunction(torch.autograd.Function):
        ctx.reflect = reflect
        if reflect != 0.0:
            ans = ans * (1.0 + reflect) - (x * reflect)

        return ans

    @staticmethod
@@ -185,7 +184,7 @@ class RandomClampFunction(torch.autograd.Function):
        reflect = ctx.reflect
        if reflect != 0.0:
            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
        return ans_grad * is_same.to(ans_grad.dtype), None, None, None, None
        return x_grad, None, None, None, None

def random_clamp(x: Tensor,
                 min: Optional[float] = None,
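
The reflect term in the two hunks above can be read off directly from the forward formula: inside [min, max] the clamp is a no-op, so ans * (1.0 + reflect) - x * reflect leaves x unchanged; outside the range the output is "reflected" back with a small negative slope, so the gradient there is -reflect instead of exactly zero. Below is a minimal sketch of just that arithmetic, not the icefall implementation: it drops the prob argument (which in random_clamp appears to control how often clamping is applied) and the custom autograd.Function.

# Simplified sketch, not the icefall implementation: only the
# clamp-plus-reflect arithmetic used in the hunks above.
import torch


def clamp_with_reflect(x: torch.Tensor, min: float, max: float, reflect: float) -> torch.Tensor:
    ans = x.clamp(min=min, max=max)
    # Inside [min, max], ans == x, so the extra term cancels and x passes
    # through unchanged; outside, the output is pushed back with slope
    # -reflect, so the gradient is -reflect rather than exactly zero.
    return ans * (1.0 + reflect) - x * reflect


x = torch.tensor([-10.0, 0.0, 5.0], requires_grad=True)
y = clamp_with_reflect(x, min=-8.0, max=2.0, reflect=0.1)
y.sum().backward()
print(y)       # tensor([-7.8000,  0.0000,  1.7000], grad_fn=<SubBackward0>)
print(x.grad)  # tensor([-0.1000,  1.0000, -0.1000])
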
@@ -211,6 +210,7 @@ def random_cast_to_half(x: Tensor,
    random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
    return torch.where(is_too_small, random_val, x).to(torch.float16)


class RandomGradFunction(torch.autograd.Function):
    """
    Does nothing in forward pass; in backward pass, gets rid of very small grads using
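
The random_val line above is what makes the cast unbiased: a value with |x| below min_abs is replaced by +/-min_abs with probability |x| / min_abs, and by zero otherwise, so its expectation is still x. A quick standalone check of that property (the min_abs and input values here are illustrative guesses, not scaling.py defaults):

# Standalone check of the expectation-preserving replacement used above;
# min_abs and the input value are illustrative, not icefall defaults.
import torch

torch.manual_seed(0)
min_abs = 5.0e-06
x = torch.full((1_000_000,), 2.0e-06)            # |x| < min_abs
keep = torch.rand_like(x) * min_abs < x.abs()    # True with probability |x| / min_abs
random_val = min_abs * x.sign() * keep
print(random_val.mean().item())                  # ~2.0e-06, i.e. close to x on average
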
@@ -248,6 +248,7 @@ class RandomGrad(torch.nn.Module):
        return RandomGradFunction.apply(x, self.min_abs)



class SoftmaxFunction(torch.autograd.Function):
    """
    Tries to handle half-precision derivatives in a randomized way that should
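
For context, RandomGrad (whose forward is shown in the hunk above) is an identity in the forward pass, so it can be dropped into an existing model purely to clean up tiny gradients on the way back. A hypothetical usage sketch follows, assuming the constructor takes min_abs (as self.min_abs in the hunk suggests) and that the recipe's scaling.py is on the import path; TinyBlock and the min_abs value are made up.

# Hypothetical usage sketch; TinyBlock and the min_abs value are made up,
# and RandomGrad is assumed importable from the recipe's scaling.py.
import torch
import torch.nn as nn
from scaling import RandomGrad


class TinyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        # Identity in forward; in backward it randomizes away gradients whose
        # magnitude falls below min_abs (see RandomGradFunction above).
        self.random_grad = RandomGrad(min_abs=5.0e-06)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.random_grad(self.linear(x))
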
@@ -255,34 +256,33 @@ class SoftmaxFunction(torch.autograd.Function):
    """
    @staticmethod
    def forward(ctx, x: Tensor, dim: int):
        with torch.cuda.amp.autocast(enabled=False):
            if x.dtype == torch.float16:
                x = x.to(torch.float32)
            ans = x.softmax(dim=dim)
            ctx.save_for_backward(ans)
            ctx.dim = dim
            return ans
        ans = x.softmax(dim=dim)
        # if x dtype is float16, x.softmax() returns a float32 because
        # (presumably) that op does not support float16, and autocast
        # is enabled.
        ctx.save_for_backward(ans)
        ctx.x_dtype = x.dtype
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        ans, = ctx.saved_tensors

        if ans.dtype == torch.float16 or ans_grad.dtype == torch.float16:
            # use a randomized approach to convert to float16
            with torch.cuda.amp.autocast(enabled=False):
                ans_grad = ans_grad.to(torch.float32)
                ans = ans.to(torch.float32)
                x_grad = ans_grad * ans
                x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
                return random_cast_to_half(x_grad), None
        else:
            with torch.cuda.amp.autocast(enabled=False):
                ans_grad = ans_grad.to(torch.float32)
                ans = ans.to(torch.float32)
                x_grad = ans_grad * ans
                x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
        if ctx.x_dtype == torch.float16:
            x_grad = random_cast_to_half(x_grad)

        return x_grad, None



def softmax(x: Tensor,
            dim: int):
    logging.info(f"torch.is_autocast_enabled()={torch.is_autocast_enabled()}, x dtype={x.dtype}")
    return SoftmaxFunction.apply(x, dim)
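
Independent of the float16 handling, both variants of the backward above use the same softmax Jacobian-vector product: x_grad = ans * (ans_grad - sum(ans * ans_grad)). A short standalone check, not part of the commit, that this formula matches autograd for a plain float32 softmax:

# Standalone check, not part of the commit: the backward formula used above
# reproduces autograd's gradient for an ordinary float32 softmax.
import torch

x = torch.randn(4, 10, requires_grad=True)
ans = x.softmax(dim=-1)
ans_grad = torch.randn_like(ans)
ans.backward(gradient=ans_grad)

x_grad = ans_grad * ans
x_grad = x_grad - ans * x_grad.sum(dim=-1, keepdim=True)
print(torch.allclose(x.grad, x_grad, atol=1.0e-06))  # True
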
@@ -907,7 +907,6 @@ class DoubleSwish(torch.nn.Module):


def _test_max_eig():

    for proportion in [0.1, 0.5, 10.0]:
        logging.info(f"proportion = {proportion}")
        x = torch.randn(100, 128)
@@ -931,7 +930,7 @@ def _test_max_eig():
        y.backward(gradient=y_grad)

        if proportion < 0.2:
            assert torch.allclose(x.grad, y_grad)
            assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
        elif proportion > 1.0:
            assert not torch.allclose(x.grad, y_grad)
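
On the tolerance change in the last hunk: torch.allclose defaults to rtol=1e-05 and atol=1e-08, which is tighter than the approximate equality this test checks, so an explicit atol loosens the comparison. For example:

# torch.allclose with default tolerances vs. an explicit absolute tolerance.
import torch

a = torch.tensor([1.000, 2.000])
b = torch.tensor([1.004, 1.998])
print(torch.allclose(a, b))                # False (rtol=1e-05, atol=1e-08)
print(torch.allclose(a, b, atol=1.0e-02))  # True
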