Fix bug regarding half precision: cast to float32 before Cholesky factorization

This commit is contained in:
Daniel Povey 2022-06-05 23:26:59 +08:00
parent d76aedb790
commit 28df3ba43f

View File

@ -667,7 +667,7 @@ class ProjDrop(torch.nn.Module):
device=x.device, dtype=x.dtype)
rr = torch.matmul(r, r.t()) # num_dropped by num_dropped
rr += 0.01 # to 100% ensure it is invertible
rr_inv = rr.cholesky().to(torch.float32).cholesky_inverse().to(x.dtype)
rr_inv = rr.to(torch.float32).cholesky().cholesky_inverse().to(x.dtype)
# OK, so r rr_inv r.t() will have eigenvalues of 1.
xr = torch.matmul(x, r.t()) # (..., num_dropped)
rr_inv_r = torch.matmul(rr_inv, r) # (num_dropped, num_channels)