diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 598c621cd..78ee5da7b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -667,7 +667,7 @@ class ProjDrop(torch.nn.Module):
                             device=x.device, dtype=x.dtype)
             rr = torch.matmul(r, r.t()) # num_dropped by num_dropped
             rr += 0.01 # to 100% ensure it is invertible
-            rr_inv = rr.cholesky().to(torch.float32).cholesky_inverse().to(x.dtype)
+            rr_inv = rr.to(torch.float32).cholesky().cholesky_inverse().to(x.dtype)
             # OK, so r rr_inv r.t() will have eigenvalues of 1.
             xr = torch.matmul(x, r.t()) # (..., num_dropped)
             rr_inv_r = torch.matmul(rr_inv, r) # (num_dropped, num_channels)
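
Context for the one-line change: it moves the float32 cast ahead of the Cholesky factorization, so that cholesky() and cholesky_inverse() both run in full precision when x (and hence rr) is in a lower-precision dtype such as float16, with only the final result cast back to x.dtype. Below is a minimal, standalone sketch of that pattern, not the icefall ProjDrop code itself; the helper name pd_inverse and the toy sizes (num_dropped=4, num_channels=16) are made up for illustration, and the tensor-method calls mirror the patch (newer PyTorch prefers torch.linalg.cholesky).

import torch

# Minimal sketch (illustrative, not the icefall module): invert a small
# positive-definite matrix via Cholesky, casting to float32 *before* the
# factorization as the patched line does, then casting back at the end.
def pd_inverse(rr: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
    # rr is assumed symmetric positive definite (num_dropped x num_dropped).
    chol = rr.to(torch.float32).cholesky()         # lower-triangular factor L
    return chol.cholesky_inverse().to(out_dtype)   # (L L^T)^{-1}, cast back

# Toy usage mirroring the shapes in the hunk (hypothetical sizes).
r = torch.randn(4, 16)                   # num_dropped=4, num_channels=16
rr = torch.matmul(r, r.t())              # 4 x 4 Gram matrix, positive definite
rr += 0.01                               # same jitter as in the patch
rr_inv = pd_inverse(rr, r.dtype)
print(torch.allclose(torch.matmul(rr, rr_inv), torch.eye(4), atol=1e-3))  # ~identity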