Store only half precision output for softmax.

Daniel Povey 2022-10-23 21:24:46 +08:00
parent d3876e32c4
commit 95aaa4a8d2
2 changed files with 9 additions and 4 deletions

@@ -36,6 +36,7 @@ from scaling import (
    _diag,
    random_clamp,
    penalize_abs_values_gt,
    softmax,
)
from torch import Tensor, nn
@@ -1161,7 +1162,12 @@ class RelPositionMultiheadAttention(nn.Module):
            bsz * num_heads, seq_len, seq_len
        )
        attn_output_weights = attn_output_weights.softmax(dim=-1)
        # Using this version of softmax, defined in scaling.py,
        # should save a little of the memory used in backprop: if
        # we are in automatic mixed precision (amp / autocast) mode,
        # it stores only the half-precision output for backprop purposes.
        attn_output_weights = softmax(attn_output_weights, dim=-1)
        attn_output_weights = nn.functional.dropout(
            attn_output_weights, p=dropout_p, training=training
        )
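
For context, a rough sketch of what the call-site change above amounts to. It assumes icefall's scaling.py is importable and a CUDA device is available; the tensor shape is made up for illustration.

import torch
from scaling import softmax  # the helper added to the import list above

# Hypothetical attention-score tensor; the shape is made up.
scores = torch.randn(16, 512, 512, device="cuda", requires_grad=True)

with torch.cuda.amp.autocast():
    # Before this commit the call site used scores.softmax(dim=-1), whose
    # output is kept for backprop in whatever dtype the op produced under
    # autocast; the scaling.py softmax keeps only a float16 copy instead.
    attn_weights = softmax(scores, dim=-1)

attn_weights.sum().backward()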

@@ -260,6 +260,8 @@ class SoftmaxFunction(torch.autograd.Function):
        # if x dtype is float16, x.softmax() returns a float32 because
        # (presumably) that op does not support float16, and autocast
        # is enabled.
        if torch.is_autocast_enabled():
            ans = ans.to(torch.float16)
        ctx.save_for_backward(ans)
        ctx.x_dtype = x.dtype
        ctx.dim = dim
@@ -273,9 +275,6 @@
        ans = ans.to(torch.float32)
        x_grad = ans_grad * ans
        x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
        if ctx.x_dtype == torch.float16:
            x_grad = random_cast_to_half(x_grad)
        return x_grad, None
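
Putting the two scaling.py hunks together, here is a minimal self-contained sketch of the pattern: a custom autograd function that computes softmax normally but, when autocast is enabled, saves only a float16 copy of the output for backprop, then redoes the gradient math in float32. The class name, the float16/float32 casts, and the softmax wrapper follow the diff; the surrounding boilerplate and the float32 cast of ans_grad are filled in as assumptions.

import torch
from torch import Tensor


class SoftmaxFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, dim: int) -> Tensor:
        ans = x.softmax(dim=dim)
        # Under autocast, x.softmax() may have produced a float32 result;
        # keep (and return) only a float16 copy so that backprop holds on
        # to half as much memory for this op.
        if torch.is_autocast_enabled():
            ans = ans.to(torch.float16)
        ctx.save_for_backward(ans)
        ctx.x_dtype = x.dtype
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        (ans,) = ctx.saved_tensors
        # Assumption: do the gradient arithmetic in float32 even though the
        # saved output may be float16.
        ans_grad = ans_grad.to(torch.float32)
        ans = ans.to(torch.float32)
        # Softmax gradient: x_grad = ans * (ans_grad - sum(ans * ans_grad))
        x_grad = ans_grad * ans
        x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
        return x_grad, None


def softmax(x: Tensor, dim: int) -> Tensor:
    return SoftmaxFunction.apply(x, dim)

The trade-off is a small quantization of the saved output in exchange for roughly halving the memory this op holds between forward and backward; at the attention call site above, that tensor has shape (bsz * num_heads, seq_len, seq_len), which is typically one of the largest activations in the model.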