diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 107d22671..acc0defa6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -400,8 +400,6 @@ class LinearWithAuxLossFunction(torch.autograd.Function): x = x.to(torch.float16) ctx.save_for_backward(x, weight, alpha) ctx.aux_grad_scale = aux_grad_scale - if torch.is_autocast_enabled(): - weight = weight.to(torch.float16) return torch.matmul(x, weight.t())