From 28df3ba43f3ca44e0e85e8110cf3526d7229f450 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 5 Jun 2022 23:26:59 +0800 Subject: [PATCH] Fix bug re half precision --- egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 598c621cd..78ee5da7b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -667,7 +667,7 @@ class ProjDrop(torch.nn.Module): device=x.device, dtype=x.dtype) rr = torch.matmul(r, r.t()) # num_dropped by num_dropped rr += 0.01 # to 100% ensure it is invertible - rr_inv = rr.cholesky().to(torch.float32).cholesky_inverse().to(x.dtype) + rr_inv = rr.to(torch.float32).cholesky().cholesky_inverse().to(x.dtype) # OK, so r rr_inv r.t() will have eigenvalues of 1. xr = torch.matmul(x, r.t()) # (..., num_dropped) rr_inv_r = torch.matmul(rr_inv, r) # (num_dropped, num_channels)