Fix bug that relates to modifying U in place

Daniel Povey 2022-06-06 17:43:15 +08:00
parent 31848dcd11
commit 4352a16f57

@@ -668,11 +668,11 @@ class RandProjDrop(torch.nn.Module):
         else:
             self._randomize_U()
         x = x.transpose(self.channel_dim, -1) # (..., num_channels)
-        x = torch.matmul(x, self.U)
+        U = self.U.clone()
+        x = torch.matmul(x, U)
         x = torch.nn.functional.dropout(x, self.dropout_rate,
                                         training=True)
-        x = torch.matmul(x, self.U.t())
+        x = torch.matmul(x, U.t())
         x = x.transpose(self.channel_dim, -1) # (..., num_channels)
         return x
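
For context, a minimal standalone sketch (not code from this repository) of why the clone helps: torch.matmul saves its matrix argument for the backward pass, so re-randomizing self.U in place before backward runs trips autograd's version check, while multiplying by a clone leaves the saved tensor untouched. The tensor names below are illustrative only.

import torch

# Illustrative only: U_ stands in for a buffer like self.U that gets
# re-randomized in place by something like _randomize_U().
x = torch.randn(4, 8, requires_grad=True)
U_ = torch.randn(8, 8)

y = torch.matmul(x, U_)   # autograd saves U_ to compute the grad w.r.t. x
U_.normal_()              # in-place update bumps U_'s version counter
# y.sum().backward()      # would now raise: "one of the variables needed for
                          # gradient computation has been modified by an
                          # inplace operation"

# The pattern used in the fix: multiply by a clone, so the tensor saved
# for backward is never the one that gets modified in place.
x = torch.randn(4, 8, requires_grad=True)
U_ = torch.randn(8, 8)
U_clone = U_.clone()
y = torch.matmul(x, U_clone)
U_.normal_()              # safe: the clone is a separate tensor
y.sum().backward()        # succeeds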