Mirror of https://github.com/k2-fsa/icefall.git
Store only half precision output for softmax.
commit 95aaa4a8d2 (parent d3876e32c4)
@@ -36,6 +36,7 @@ from scaling import (
     _diag,
     random_clamp,
     penalize_abs_values_gt,
+    softmax,
 )
 from torch import Tensor, nn
 
@@ -1161,7 +1162,12 @@ class RelPositionMultiheadAttention(nn.Module):
             bsz * num_heads, seq_len, seq_len
         )
 
-        attn_output_weights = attn_output_weights.softmax(dim=-1)
+        # Using this version of softmax, defined in scaling.py,
+        # should save a little of the memory used in backprop by, if
+        # we are in automatic mixed precision mode (amp) == autocast,
+        # only storing the half-precision output for backprop purposes.
+        attn_output_weights = softmax(attn_output_weights, dim=-1)
+
         attn_output_weights = nn.functional.dropout(
             attn_output_weights, p=dropout_p, training=training
         )
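As a rough illustration of what the change buys (not part of the commit; the batch size, head count and sequence length below are made up), the tensor cached for backprop here is the attention-weight matrix of shape (bsz * num_heads, seq_len, seq_len), and keeping it in float16 rather than float32 halves that buffer:

# Back-of-the-envelope estimate of the activation memory saved by caching
# the softmax output in half precision (all sizes are hypothetical).
bsz, num_heads, seq_len = 32, 8, 512
n_elems = bsz * num_heads * seq_len * seq_len   # elements in attn_output_weights

print(f"cached in float32: {n_elems * 4 / 2**20:.0f} MiB")   # 256 MiB
print(f"cached in float16: {n_elems * 2 / 2**20:.0f} MiB")   # 128 MiB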
@@ -260,6 +260,8 @@ class SoftmaxFunction(torch.autograd.Function):
         # if x dtype is float16, x.softmax() returns a float32 because
         # (presumably) that op does not support float16, and autocast
         # is enabled.
+        if torch.is_autocast_enabled():
+            ans = ans.to(torch.float16)
         ctx.save_for_backward(ans)
         ctx.x_dtype = x.dtype
         ctx.dim = dim
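To see the dtype behaviour the comment above is describing, here is a small standalone check (an illustration only, assuming a CUDA device and autocast's usual op policy; it is not part of the commit):

import torch

if torch.cuda.is_available():
    x = torch.randn(4, 10, device="cuda", dtype=torch.float16)
    with torch.cuda.amp.autocast():
        ans = x.softmax(dim=-1)
        # softmax is on autocast's float32 list, so the result is promoted
        # even though x itself is float16.
        print(ans.dtype)                    # expected: torch.float32
        # This is why the forward pass above casts back to float16 before
        # ctx.save_for_backward(ans) whenever autocast is enabled.
        if torch.is_autocast_enabled():
            ans = ans.to(torch.float16)
        print(ans.dtype)                    # torch.float16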
@@ -273,9 +275,6 @@ class SoftmaxFunction(torch.autograd.Function):
         ans = ans.to(torch.float32)
         x_grad = ans_grad * ans
         x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
-        if ctx.x_dtype == torch.float16:
-            x_grad = random_cast_to_half(x_grad)
-
         return x_grad, None
 
 
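Taken together, the two scaling.py hunks implement a custom autograd.Function that caches only the (possibly half-precision) softmax output and redoes the gradient arithmetic in float32. The sketch below is a condensed, self-contained version of that pattern, not the real scaling.py code: the class name is invented, random_cast_to_half is left out (the commit removes it from the backward), and the final check only confirms that the hand-written backward matches PyTorch's own softmax gradient.

import torch
from torch import Tensor


class MemEfficientSoftmax(torch.autograd.Function):
    # Sketch of the pattern: save only the softmax output (float16 under
    # autocast) and compute the backward formula y * (g - sum(g * y)) in
    # float32.

    @staticmethod
    def forward(ctx, x: Tensor, dim: int) -> Tensor:
        ans = x.softmax(dim=dim)
        if torch.is_autocast_enabled():
            ans = ans.to(torch.float16)   # store only half precision for backprop
        ctx.save_for_backward(ans)
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        (ans,) = ctx.saved_tensors
        ans = ans.to(torch.float32)
        ans_grad = ans_grad.to(torch.float32)
        x_grad = ans_grad * ans
        x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
        return x_grad, None


# Quick check outside autocast (so the saved output stays float32): the
# custom backward should agree with PyTorch's built-in softmax gradient.
x = torch.randn(3, 5, requires_grad=True)
g = torch.randn(3, 5)
(ref,) = torch.autograd.grad(x.softmax(dim=-1), x, g)
(got,) = torch.autograd.grad(MemEfficientSoftmax.apply(x, -1), x, g)
print(torch.allclose(ref, got, atol=1e-6))   # True

The only tensor held for the backward pass is the saved softmax output, so under autocast it costs half as much memory as a float32 copy, while doing the backward arithmetic in float32 keeps the gradient reasonably accurate.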