mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 16:44:20 +00:00
Make it work for half
This commit is contained in:
parent
e535887abb
commit
d76aedb790
@ -667,7 +667,7 @@ class ProjDrop(torch.nn.Module):
|
||||
device=x.device, dtype=x.dtype)
|
||||
rr = torch.matmul(r, r.t()) # num_dropped by num_dropped
|
||||
rr += 0.01 # to 100% ensure it is invertible
|
||||
rr_inv = rr.cholesky().cholesky_inverse()
|
||||
rr_inv = rr.cholesky().to(torch.float32).cholesky_inverse().to(x.dtype)
|
||||
# OK, so r rr_inv r.t() will have eigenvalues of 1.
|
||||
xr = torch.matmul(x, r.t()) # (..., num_dropped)
|
||||
rr_inv_r = torch.matmul(rr_inv, r) # (num_dropped, num_channels)
|
||||
|
Loading…
x
Reference in New Issue
Block a user