Store only half precision output for softmax.
commit 95aaa4a8d2
parent d3876e32c4
@@ -36,6 +36,7 @@ from scaling import (
     _diag,
     random_clamp,
     penalize_abs_values_gt,
+    softmax,
 )
 from torch import Tensor, nn
@@ -1161,7 +1162,12 @@ class RelPositionMultiheadAttention(nn.Module):
             bsz * num_heads, seq_len, seq_len
         )
 
-        attn_output_weights = attn_output_weights.softmax(dim=-1)
+        # Using this version of softmax, defined in scaling.py,
+        # should save a little of the memory used in backprop by, if
+        # we are in automatic mixed precision mode (amp) == autocast,
+        # only storing the half-precision output for backprop purposes.
+        attn_output_weights = softmax(attn_output_weights, dim=-1)
+
         attn_output_weights = nn.functional.dropout(
             attn_output_weights, p=dropout_p, training=training
         )
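For a sense of the saving: the attention weights here have shape (bsz * num_heads, seq_len, seq_len), and the softmax output is what gets saved for backprop. Keeping it in float16 instead of float32 halves that buffer. A rough back-of-the-envelope estimate (the shapes below are made up purely for illustration):

# Rough size of the saved softmax output for hypothetical shapes.
bsz, num_heads, seq_len = 32, 8, 512
n_elems = bsz * num_heads * seq_len * seq_len   # 67,108,864 elements
print(n_elems * 4 / 2**20)  # ~256.0 MiB if kept in float32
print(n_elems * 2 / 2**20)  # ~128.0 MiB if kept in float16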
@@ -260,6 +260,8 @@ class SoftmaxFunction(torch.autograd.Function):
         # if x dtype is float16, x.softmax() returns a float32 because
         # (presumably) that op does not support float16, and autocast
         # is enabled.
+        if torch.is_autocast_enabled():
+            ans = ans.to(torch.float16)
         ctx.save_for_backward(ans)
         ctx.x_dtype = x.dtype
         ctx.dim = dim
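Taken together, the scaling.py hunks implement a custom torch.autograd.Function that computes softmax as usual but, under autocast, saves only a half-precision copy of the output for the backward pass. The following is a minimal, self-contained sketch of that pattern, not the icefall code itself (the names MemoryEfficientSoftmax and softmax_sketch are made up):

import torch


class MemoryEfficientSoftmax(torch.autograd.Function):
    """Sketch of a softmax that stores only a float16 output under autocast."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, dim: int) -> torch.Tensor:
        ans = x.softmax(dim=dim)
        if torch.is_autocast_enabled():
            # Keep (and return) only a half-precision copy; this is the
            # memory saving the commit message refers to.
            ans = ans.to(torch.float16)
        ctx.save_for_backward(ans)
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: torch.Tensor):
        (ans,) = ctx.saved_tensors
        # Do the gradient arithmetic in float32 even if the saved copy
        # is float16.
        ans = ans.to(torch.float32)
        ans_grad = ans_grad.to(torch.float32)
        x_grad = ans_grad * ans
        x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
        return x_grad, None


def softmax_sketch(x: torch.Tensor, dim: int) -> torch.Tensor:
    return MemoryEfficientSoftmax.apply(x, dim)

At the call site this is then a drop-in replacement for Tensor.softmax, which is exactly how the attention hunk above uses the scaling.py version.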
@@ -273,9 +275,6 @@ class SoftmaxFunction(torch.autograd.Function):
         ans = ans.to(torch.float32)
         x_grad = ans_grad * ans
         x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
-        if ctx.x_dtype == torch.float16:
-            x_grad = random_cast_to_half(x_grad)
-
         return x_grad, None
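The two x_grad lines are the standard softmax Jacobian-vector product, x_grad = ans * (ans_grad - sum(ans_grad * ans)); with the random_cast_to_half call removed, the gradient is now returned without being rounded back down to half precision here. A quick sanity check of the formula against PyTorch's own autograd (plain PyTorch, made-up shapes):

import torch

x = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
ans = x.softmax(dim=-1)
ans_grad = torch.randn_like(ans)

# Reference gradient from autograd.
(ref,) = torch.autograd.grad(ans, x, grad_outputs=ans_grad)

# The two-line formula from SoftmaxFunction.backward().
x_grad = ans_grad * ans
x_grad = x_grad - ans * x_grad.sum(dim=-1, keepdim=True)

print(torch.allclose(ref, x_grad))  # True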