From dbfbd8016bbb1a77712987645a9752767da0779a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 25 Oct 2022 13:16:00 +0800 Subject: [PATCH] Cast to float16 in DoubleSwish forward --- egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index e93b8dded..c569fafad 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -939,6 +939,8 @@ class DoubleSwishFunction(torch.autograd.Function): assert d_scaled.max() < 256.0 d_int = d_scaled.to(torch.uint8) ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) return y @staticmethod