Bug fix RE GPU device

This commit is contained in:
Daniel Povey 2022-06-06 15:40:20 +08:00
parent 71e927411a
commit 6fdb356315

View File

@ -657,7 +657,7 @@ class FixedProjDrop(torch.nn.Module):
rand_mat = torch.randn(num_channels, num_channels)
U, _, _ = rand_mat.svd()
self.U = U # a random orthogonal square matrix. will be a buffer.
self.register_buffer('U', U) # a random orthogonal square matrix. will be a buffer.
def forward(self, x: Tensor) -> Tensor: