From d1ee1f2d986252dee42ce9989e0434d769fea9d1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 26 Nov 2022 14:25:27 +0800 Subject: [PATCH] Try to save memory in autocast mode. --- egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 84c408c12..107d22671 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -396,8 +396,12 @@ class LinearWithAuxLossFunction(torch.autograd.Function): In the backward pass it will include an auxiliary loss based on predicting x from matmul(y, weight). """ + if torch.is_autocast_enabled(): + x = x.to(torch.float16) ctx.save_for_backward(x, weight, alpha) ctx.aux_grad_scale = aux_grad_scale + if torch.is_autocast_enabled(): + weight = weight.to(torch.float16) return torch.matmul(x, weight.t())