From d76aedb790581de2f2fd41bc6b663bf1c0012099 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Jun 2022 23:25:51 +0800 Subject: [PATCH] Make it work for half --- egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index cc832e1bc..598c621cd 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -667,7 +667,7 @@ class ProjDrop(torch.nn.Module): device=x.device, dtype=x.dtype) rr = torch.matmul(r, r.t()) # num_dropped by num_dropped rr += 0.01 # to 100% ensure it is invertible - rr_inv = rr.cholesky().cholesky_inverse() + rr_inv = rr.cholesky().to(torch.float32).cholesky_inverse().to(x.dtype) # OK, so r rr_inv r.t() will have eigenvalues of 1. xr = torch.matmul(x, r.t()) # (..., num_dropped) rr_inv_r = torch.matmul(rr_inv, r) # (num_dropped, num_channels)