From 1ebc3dd15810387796ca6ec5024e433823f1a02a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 25 Nov 2022 16:20:28 +0800 Subject: [PATCH] Bug fixes to LinearWithAuxLoss --- .../ASR/pruned_transducer_stateless7/scaling.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 7dbc2b1d3..eb556b4e9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -404,8 +404,15 @@ class LinearWithAuxLossFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, Tensor, None]: x, weight, alpha = ctx.saved_tensors + + x_grad = torch.matmul(ans_grad, weight.to(ans_grad.dtype)) + weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(), + x.reshape(-1, x.shape[-1]).to(ans_grad.dtype)) + + with torch.cuda.amp.autocast(enabled=False): with torch.enable_grad(): + x = x.to(weight.dtype) x, weight, alpha = x.detach(), weight.detach(), alpha.detach() weight.requires_grad = True alpha.requires_grad = True @@ -423,9 +430,6 @@ class LinearWithAuxLossFunction(torch.autograd.Function): weight_aux_grad = weight.grad alpha_grad = alpha.grad - x_grad = torch.matmul(ans_grad, weight.to(ans_grad.dtype)) - weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(), - x.reshape(-1, x.shape[-1]).to(ans_grad.dtype)) with torch.cuda.amp.autocast(enabled=False): weight_grad_norm = weight_grad.to(torch.float32).norm()