Simplify implementation as current idea was not working to decorrelate

This commit is contained in:
Daniel Povey 2022-06-08 10:24:41 +08:00
parent 135be1e19c
commit a83bde1372
2 changed files with 8 additions and 6 deletions

View File

@ -752,10 +752,9 @@ class Decorrelate(torch.nn.Module):
self.eps = eps
self.beta = beta
rand_mat = torch.randn(num_channels, num_channels)
U, _, _ = rand_mat.svd()
self.register_buffer('U', U) # a random orthogonal square matrix. will be a buffer.
#rand_mat = torch.randn(num_channels, num_channels)
#U, _, _ = rand_mat.svd()
#self.register_buffer('U', U) # a random orthogonal square matrix. will be a buffer.
self.register_buffer('T1', torch.eye(num_channels))
self.register_buffer('rand_scales', torch.zeros(num_channels))
@ -887,7 +886,7 @@ class Decorrelate(torch.nn.Module):
x_bypass = x
if True:
if False:
# This block, in effect, multiplies x by a random orthogonal matrix,
# giving us random noise.
perm = self._randperm_like(x)
@ -901,6 +900,9 @@ class Decorrelate(torch.nn.Module):
x_next.scatter_(-1, perm, x)
x = x_next
mask = (torch.randn_like(x) > 0.5)
x = x - (x * mask) * 2
x = (x * self.rand_scales) + (x_bypass * self.nonrand_scales)
x = torch.matmul(x, self.T2)

View File

@ -199,7 +199,7 @@ class ConformerEncoderLayer(nn.Module):
)
self.dropout = torch.nn.Dropout(dropout)
self.decorrelate = Decorrelate(d_model, apply_prob=0.25, dropout_rate=0.1)
self.decorrelate = Decorrelate(d_model, apply_prob=0.25, dropout_rate=0.05)
def forward(