diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index cc832e1bc..598c621cd 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -667,7 +667,7 @@ class ProjDrop(torch.nn.Module): device=x.device, dtype=x.dtype) rr = torch.matmul(r, r.t()) # num_dropped by num_dropped rr += 0.01 # to 100% ensure it is invertible - rr_inv = rr.cholesky().cholesky_inverse() + rr_inv = rr.cholesky().to(torch.float32).cholesky_inverse().to(x.dtype) # OK, so r rr_inv r.t() will have eigenvalues of 1. xr = torch.matmul(x, r.t()) # (..., num_dropped) rr_inv_r = torch.matmul(rr_inv, r) # (num_dropped, num_channels)