Implement randomized backprop for softmax.

Daniel Povey 2022-10-19 19:16:03 +08:00
parent c3c655d0bd
commit 9c54906e63
2 changed files with 68 additions and 3 deletions
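For context: the backward pass added in scaling.py below relies on the standard softmax gradient identity, x_grad = ans * (ans_grad - (ans_grad * ans).sum(dim)). A minimal sketch of that identity in plain PyTorch (variable names here are illustrative, not part of the patch):

import torch

# y = softmax(x); for an incoming gradient g = dL/dy, the gradient w.r.t. x is
#   dL/dx = y * (g - (g * y).sum(dim, keepdim=True)),
# which is what SoftmaxFunction.backward computes below.
x = torch.randn(3, 5, requires_grad=True)
y = x.softmax(dim=-1)
g = torch.randn_like(y)

manual = y * (g - (g * y).sum(dim=-1, keepdim=True))
(auto,) = torch.autograd.grad(y, x, grad_outputs=g)
print(torch.allclose(manual, auto, atol=1.0e-06))  # True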

View File

@@ -34,7 +34,8 @@ from scaling import (
Whiten,
Identity,
_diag,
-random_clamp
+random_clamp,
+softmax,
)
from torch import Tensor, nn
@@ -1148,7 +1149,7 @@ class RelPositionMultiheadAttention(nn.Module):
bsz * num_heads, seq_len, seq_len
)
-attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
+attn_output_weights = softmax(attn_output_weights, dim=-1)
attn_output_weights = nn.functional.dropout(
attn_output_weights, p=dropout_p, training=training
)
@@ -1561,7 +1562,7 @@ class AttentionCombine(nn.Module):
single_prob_mask)
weights = weights.masked_fill(mask, float('-inf'))
-weights = weights.softmax(dim=1)
+weights = softmax(weights, dim=1)
# (num_frames, num_channels, num_inputs) * (num_frames, num_inputs, 1) -> (num_frames, num_channels, 1),
ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))
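Both call sites above are one-line swaps of the framework softmax for the custom `softmax` imported from scaling.py. A small usage sketch (assuming scaling.py is on the import path); the forward result matches torch's softmax exactly, only the half-precision backward behavior differs:

import torch
from scaling import softmax  # the autograd function defined in scaling.py below

scores = torch.randn(8, 100, 100, requires_grad=True)
probs = softmax(scores, dim=-1)
print(torch.allclose(probs, scores.softmax(dim=-1)))  # True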

View File

@@ -185,6 +185,57 @@ def random_clamp(x: Tensor,
    return RandomClampFunction.apply(x, min, max, prob)


def random_cast_to_half(x: Tensor,
                        min_abs: float = 1.0e-03) -> Tensor:
    """
    A randomized way of casting a floating-point tensor to half precision.
    Elements whose absolute value is below `min_abs` are randomly rounded to
    0.0 or +-min_abs in a way that preserves their expected value.
    """
    if x.dtype == torch.float16:
        return x
    x_sign = x.sign()
    x_abs = x.abs()
    is_too_small = (x_abs < min_abs)
    # For elements where is_too_small is true, random_val will be +-min_abs with
    # probability (x_abs / min_abs) and 0.0 otherwise, so the expectation is
    # preserved for those elements.
    random_val = min_abs * x_sign * (torch.rand_like(x) * min_abs < x_abs)
    return torch.where(is_too_small, random_val, x).to(torch.float16)
class SoftmaxFunction(torch.autograd.Function):
    """
    Tries to handle half-precision derivatives in a randomized way that should
    be more accurate for training than the default behavior.
    """

    @staticmethod
    def forward(ctx, x: Tensor, dim: int):
        ans = x.softmax(dim=dim)
        # we don't save x itself; the answer is enough to compute the gradient.
        ctx.save_for_backward(ans)
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        (ans,) = ctx.saved_tensors
        # softmax gradient: x_grad = ans * (ans_grad - (ans_grad * ans).sum(dim))
        if ans.dtype == torch.float16 or ans_grad.dtype == torch.float16:
            # compute in float32, then use a randomized approach to convert
            # the result back to float16.
            with torch.cuda.amp.autocast(enabled=False):
                ans_grad = ans_grad.to(torch.float32)
                ans = ans.to(torch.float32)
                x_grad = ans_grad * ans
                x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
                return random_cast_to_half(x_grad), None
        else:
            x_grad = ans_grad * ans
            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
            return x_grad, None


def softmax(x: Tensor,
            dim: int) -> Tensor:
    return SoftmaxFunction.apply(x, dim)


class MaxEigLimiterFunction(torch.autograd.Function):
    @staticmethod
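The point of random_cast_to_half is that the expected value of small elements survives the cast instead of being squashed by fp16 rounding. A quick sanity check of that property (a sketch, not part of the commit; it assumes scaling.py is importable):

import torch
from scaling import random_cast_to_half  # assumes scaling.py is on the import path

x = torch.full((100000,), 1.0e-04)           # well below min_abs = 1e-3
y = random_cast_to_half(x, min_abs=1.0e-03)  # each element becomes 0.0 or min_abs

print(y.unique())        # only 0.0 and (roughly) 1e-3 appear
print(y.float().mean())  # close to 1.0e-04, so the expectation is preserved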
@@ -942,11 +993,24 @@ def _test_double_swish_deriv():
    torch.autograd.gradcheck(m, x)


def _test_softmax():
    # check that the custom softmax's gradient matches torch's own softmax
    # in double precision (the randomized path is only taken for float16).
    a = torch.randn(2, 10, dtype=torch.float64)
    b = a.clone()
    a.requires_grad = True
    b.requires_grad = True

    a.softmax(dim=1)[:, 0].sum().backward()
    print("a grad = ", a.grad)
    softmax(b, dim=1)[:, 0].sum().backward()
    print("b grad = ", b.grad)
    assert torch.allclose(a.grad, b.grad)


if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    _test_softmax()
    _test_whiten()
    _test_max_eig()
    _test_activation_balancer_sign()
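The test above only exercises the float64 path. A hedged sketch of what an additional check for the randomized float16 branch could look like (the function name and tolerance are illustrative, and it is guarded to run only when CUDA is available, since half precision is really a GPU training concern):

import torch
from scaling import softmax  # assumes scaling.py is on the import path

def _test_softmax_fp16_unbiased():
    # Average many randomized fp16 backward passes; the mean should approach
    # the exact float32 gradient, since random_cast_to_half is unbiased.
    if not torch.cuda.is_available():
        return
    device = torch.device("cuda")
    torch.manual_seed(0)
    x16 = torch.randn(2, 10, device=device).to(torch.float16)

    grads = []
    for _ in range(200):
        x = x16.clone().requires_grad_(True)
        softmax(x, dim=1)[:, 0].sum().backward()
        grads.append(x.grad.to(torch.float32))
    avg_grad = torch.stack(grads).mean(dim=0)

    ref = x16.to(torch.float32).requires_grad_(True)
    ref.softmax(dim=1)[:, 0].sum().backward()
    # loose tolerance: fp16 rounding in the forward pass still contributes error
    assert torch.allclose(avg_grad, ref.grad, atol=2.0e-03)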