diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index cd3bc07c6..9015f3f3b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -752,10 +752,9 @@ class Decorrelate(torch.nn.Module): self.eps = eps self.beta = beta - rand_mat = torch.randn(num_channels, num_channels) - U, _, _ = rand_mat.svd() - - self.register_buffer('U', U) # a random orthogonal square matrix. will be a buffer. + #rand_mat = torch.randn(num_channels, num_channels) + #U, _, _ = rand_mat.svd() + #self.register_buffer('U', U) # a random orthogonal square matrix. will be a buffer. self.register_buffer('T1', torch.eye(num_channels)) self.register_buffer('rand_scales', torch.zeros(num_channels)) @@ -887,7 +886,7 @@ class Decorrelate(torch.nn.Module): x_bypass = x - if True: + if False: # This block, in effect, multiplies x by a random orthogonal matrix, # giving us random noise. perm = self._randperm_like(x) @@ -901,6 +900,9 @@ class Decorrelate(torch.nn.Module): x_next.scatter_(-1, perm, x) x = x_next + mask = (torch.randn_like(x) > 0.5) + x = x - (x * mask) * 2 + x = (x * self.rand_scales) + (x_bypass * self.nonrand_scales) x = torch.matmul(x, self.T2) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 138950d55..b64b9a6bc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -199,7 +199,7 @@ class ConformerEncoderLayer(nn.Module): ) self.dropout = torch.nn.Dropout(dropout) - self.decorrelate = Decorrelate(d_model, apply_prob=0.25, dropout_rate=0.1) + self.decorrelate = Decorrelate(d_model, apply_prob=0.25, dropout_rate=0.05) def forward(