From bfeeddda81577f2c3e5086bb7e40ac201962c562 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 16 May 2023 12:18:09 +0800
Subject: [PATCH] Reduce mem consumption of softmax backward

---
 egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 603110d95..93257fdb6 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -276,7 +276,8 @@ class SoftmaxFunction(torch.autograd.Function):
             ans_grad = ans_grad.to(torch.float32)
             ans = ans.to(torch.float32)
             x_grad = ans_grad * ans
-            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
+            ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
+            x_grad -= ans
             return x_grad, None
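
The change is algebraically a no-op: for y = softmax(x) along `dim` and upstream
gradient g, the softmax backward is x_grad = y * g - y * sum(y * g, dim). The old
line materialized the correction term `ans * x_grad.sum(...)` in a fresh full-size
tensor and allocated another for the subtraction result; the patched lines compute
the same quantity with in-place `*=` and `-=`, reusing the `ans` and `x_grad`
buffers, so the backward pass needs two fewer activation-sized temporaries.

Below is a minimal self-contained sketch of the patched function, assuming a plain
softmax forward (the real SoftmaxFunction in scaling.py also handles autocast /
half precision, which is omitted here); the class name MemEfficientSoftmax is made
up for illustration:

import torch
from torch import Tensor


class MemEfficientSoftmax(torch.autograd.Function):
    """Sketch of softmax with the patch's memory-saving backward."""

    @staticmethod
    def forward(ctx, x: Tensor, dim: int) -> Tensor:
        ans = x.softmax(dim=dim)
        ctx.save_for_backward(ans)
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        (ans,) = ctx.saved_tensors
        ans_grad = ans_grad.to(torch.float32)
        ans = ans.to(torch.float32)
        x_grad = ans_grad * ans
        # The patched lines: fold the correction term into existing
        # buffers.  `ans *= ...` overwrites the local float32 copy of the
        # saved softmax output, and `x_grad -= ans` subtracts in place,
        # so no extra full-size tensor is allocated for either
        # `ans * x_grad.sum(...)` or the difference.
        ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
        x_grad -= ans
        return x_grad, None

A quick equivalence check against PyTorch's built-in softmax gradient (the in-place
form still computes the standard Jacobian-vector product
x_grad_i = y_i * g_i - y_i * sum_j(y_j * g_j)):

x = torch.randn(4, 10, requires_grad=True)
g = torch.randn(4, 10)
MemEfficientSoftmax.apply(x, -1).backward(g)

x_ref = x.detach().clone().requires_grad_(True)
x_ref.softmax(dim=-1).backward(g)
assert torch.allclose(x.grad, x_ref.grad, atol=1e-6)

One caveat of the in-place trick: when the saved `ans` is already float32,
`ans.to(torch.float32)` returns the same tensor, so the multiply clobbers the
tensor that forward returned. That is harmless for a single first-order backward
that does not reuse the softmax output afterwards, which appears to be the intended
usage here; under autocast the saved tensor is half precision, so `.to()` makes a
copy and nothing user-visible is touched.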