diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 84c408c12..107d22671 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -396,8 +396,12 @@ class LinearWithAuxLossFunction(torch.autograd.Function): In the backward pass it will include an auxiliary loss based on predicting x from matmul(y, weight). """ + if torch.is_autocast_enabled(): + x = x.to(torch.float16) ctx.save_for_backward(x, weight, alpha) ctx.aux_grad_scale = aux_grad_scale + if torch.is_autocast_enabled(): + weight = weight.to(torch.float16) return torch.matmul(x, weight.t())