From 95aaa4a8d2eaa73bccdcaf31984789f2c0a9baa7 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 23 Oct 2022 21:24:46 +0800
Subject: [PATCH] Store only half precision output for softmax.

---
 .../ASR/pruned_transducer_stateless7/conformer.py | 8 +++++++-
 .../ASR/pruned_transducer_stateless7/scaling.py   | 5 ++---
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 8734f266e..efed5eb9d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -36,6 +36,7 @@ from scaling import (
     _diag,
     random_clamp,
     penalize_abs_values_gt,
+    softmax,
 )
 
 from torch import Tensor, nn
@@ -1161,7 +1162,12 @@ class RelPositionMultiheadAttention(nn.Module):
             bsz * num_heads, seq_len, seq_len
         )
 
-        attn_output_weights = attn_output_weights.softmax(dim=-1)
+        # Using this version of softmax, defined in scaling.py, should
+        # save a little of the memory used in backprop: if we are in
+        # automatic mixed precision (amp / autocast) mode, it stores
+        # only the half-precision output for backprop purposes.
+        attn_output_weights = softmax(attn_output_weights, dim=-1)
+
         attn_output_weights = nn.functional.dropout(
             attn_output_weights, p=dropout_p, training=training
         )
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 7e1b9a822..f741d853c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -260,6 +260,8 @@ class SoftmaxFunction(torch.autograd.Function):
         # if x dtype is float16, x.softmax() returns a float32 because
         # (presumably) that op does not support float16, and autocast
         # is enabled.
+        if torch.is_autocast_enabled():
+            ans = ans.to(torch.float16)
         ctx.save_for_backward(ans)
         ctx.x_dtype = x.dtype
         ctx.dim = dim
@@ -273,9 +275,6 @@
         with torch.cuda.amp.autocast(enabled=False):
             ans_grad = ans_grad.to(torch.float32)
             ans = ans.to(torch.float32)
             x_grad = ans_grad * ans
             x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
-            if ctx.x_dtype == torch.float16:
-                x_grad = random_cast_to_half(x_grad)
-
             return x_grad, None
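
Note: the sketch below is a minimal, self-contained illustration of the technique this patch applies in scaling.py: a custom autograd softmax that, when amp/autocast is active, keeps only a float16 copy of its output for backprop and does the gradient arithmetic in float32. It is not the actual SoftmaxFunction/softmax from scaling.py; the names MemoryEfficientSoftmax and memory_efficient_softmax, and the tensors in the usage comment, are illustrative only.

import torch
from torch import Tensor


class MemoryEfficientSoftmax(torch.autograd.Function):
    """Softmax that, when autocast (amp) is active, saves only a
    half-precision copy of its output for the backward pass, roughly
    halving the activation memory this op keeps alive."""

    @staticmethod
    def forward(ctx, x: Tensor, dim: int) -> Tensor:
        ans = x.softmax(dim=dim)
        if torch.is_autocast_enabled():
            # Under autocast, softmax typically produces float32 output
            # (cf. the comment in scaling.py); keep only a float16 copy
            # for backprop.
            ans = ans.to(torch.float16)
        ctx.save_for_backward(ans)
        ctx.dim = dim
        ctx.x_dtype = x.dtype
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        (ans,) = ctx.saved_tensors
        # Do the gradient math in float32 even though the saved output
        # may be float16.
        ans = ans.to(torch.float32)
        ans_grad = ans_grad.to(torch.float32)
        # d/dx of softmax: y * (g - sum_j g_j * y_j), reduced over `dim`.
        x_grad = ans_grad * ans
        x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
        # Return the gradient in the dtype of the original input.
        return x_grad.to(ctx.x_dtype), None


def memory_efficient_softmax(x: Tensor, dim: int) -> Tensor:
    return MemoryEfficientSoftmax.apply(x, dim)


# Hypothetical usage inside an attention module, under autocast:
#     with torch.cuda.amp.autocast():
#         attn_weights = memory_efficient_softmax(attn_scores, dim=-1)

The trade-off is a small loss of precision in the values reused during backprop in exchange for storing half as many bytes between the forward and backward passes of this op.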