mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
Simplify implementation as current idea was not working to decorrelate
This commit is contained in:
parent
135be1e19c
commit
a83bde1372
@ -752,10 +752,9 @@ class Decorrelate(torch.nn.Module):
|
||||
self.eps = eps
|
||||
self.beta = beta
|
||||
|
||||
rand_mat = torch.randn(num_channels, num_channels)
|
||||
U, _, _ = rand_mat.svd()
|
||||
|
||||
self.register_buffer('U', U) # a random orthogonal square matrix. will be a buffer.
|
||||
#rand_mat = torch.randn(num_channels, num_channels)
|
||||
#U, _, _ = rand_mat.svd()
|
||||
#self.register_buffer('U', U) # a random orthogonal square matrix. will be a buffer.
|
||||
|
||||
self.register_buffer('T1', torch.eye(num_channels))
|
||||
self.register_buffer('rand_scales', torch.zeros(num_channels))
|
||||
@ -887,7 +886,7 @@ class Decorrelate(torch.nn.Module):
|
||||
|
||||
x_bypass = x
|
||||
|
||||
if True:
|
||||
if False:
|
||||
# This block, in effect, multiplies x by a random orthogonal matrix,
|
||||
# giving us random noise.
|
||||
perm = self._randperm_like(x)
|
||||
@ -901,6 +900,9 @@ class Decorrelate(torch.nn.Module):
|
||||
x_next.scatter_(-1, perm, x)
|
||||
x = x_next
|
||||
|
||||
mask = (torch.randn_like(x) > 0.5)
|
||||
x = x - (x * mask) * 2
|
||||
|
||||
x = (x * self.rand_scales) + (x_bypass * self.nonrand_scales)
|
||||
|
||||
x = torch.matmul(x, self.T2)
|
||||
|
@ -199,7 +199,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.decorrelate = Decorrelate(d_model, apply_prob=0.25, dropout_rate=0.1)
|
||||
self.decorrelate = Decorrelate(d_model, apply_prob=0.25, dropout_rate=0.05)
|
||||
|
||||
|
||||
def forward(
|
||||
|
Loading…
x
Reference in New Issue
Block a user