This commit is contained in:
Daniel Povey 2022-06-08 11:05:29 +08:00
parent a83bde1372
commit e7886d49a9

View File

@ -900,7 +900,7 @@ class Decorrelate(torch.nn.Module):
x_next.scatter_(-1, perm, x)
x = x_next
mask = (torch.randn_like(x) > 0.5)
mask = (torch.rand_like(x) > 0.5)
x = x - (x * mask) * 2
x = (x * self.rand_scales) + (x_bypass * self.nonrand_scales)