diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 23afb4387..a3b8064b5 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -255,10 +255,13 @@ class SoftmaxFunction(torch.autograd.Function):
     """
 
     @staticmethod
     def forward(ctx, x: Tensor, dim: int):
-        ans = x.softmax(dim=dim)
-        ctx.save_for_backward(ans)
-        ctx.dim = dim
-        return ans
+        with torch.cuda.amp.autocast(enabled=False):
+            if x.dtype == torch.float16:
+                x = x.to(torch.float32)
+            ans = x.softmax(dim=dim)
+            ctx.save_for_backward(ans)
+            ctx.dim = dim
+            return ans
 
     @staticmethod
     def backward(ctx, ans_grad: Tensor):