From 0a997d64c4544406d16cfe75e0f9729aab62728c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 25 Nov 2022 16:07:47 +0800 Subject: [PATCH] Fixes for half precision --- egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 0fcd73878..7dbc2b1d3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -423,9 +423,9 @@ class LinearWithAuxLossFunction(torch.autograd.Function): weight_aux_grad = weight.grad alpha_grad = alpha.grad - x_grad = torch.matmul(ans_grad, weight) + x_grad = torch.matmul(ans_grad, weight.to(ans_grad.dtype)) weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(), - x.reshape(-1, x.shape[-1])) + x.reshape(-1, x.shape[-1]).to(ans_grad.dtype)) with torch.cuda.amp.autocast(enabled=False): weight_grad_norm = weight_grad.to(torch.float32).norm()