diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 3403c291b..8e18ee4b5 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -713,6 +713,25 @@ class GaussProjDrop(torch.nn.Module):
         x = (x_next * self.rand_scale + x_bypass)
         return x
 
+class PseudoNormalizeFunction(torch.autograd.Function):
+    """
+    Function object that is the identity in the forward pass and, in the
+    backward pass, removes the component of the derivative in the direction of
+    x itself (as if x had gone through some kind of normalization layer).
+    """
+    @staticmethod
+    def forward(ctx, x: Tensor) -> Tensor:
+        ctx.save_for_backward(x)
+        return x
+
+    @staticmethod
+    def backward(ctx, x_grad: Tensor) -> Tensor:
+        x, = ctx.saved_tensors
+        eps = 1.0e-20
+        x_sumsq = (x**2).sum() + eps
+        grad_x_sum = (x_grad * x).sum()
+        return x_grad - x * (grad_x_sum / x_sumsq)  # project out the x-direction
+
 
 def _compute_correlation_loss(cov: Tensor,
                               eps: float) -> Tensor:
@@ -740,7 +759,9 @@ def _update_cov_stats(cov: Tensor,
         x: Tensor of features/activations, of shape (num_frames, num_channels)
        beta: The decay constant for the stats, e.g. 0.8.
     """
-    return cov * beta + torch.matmul(x.t(), x) * (1-beta)
+    new_cov = torch.matmul(x.t(), x)
+    new_cov = PseudoNormalizeFunction.apply(new_cov)
+    return cov * beta + new_cov * (1-beta)
 
 
 class DecorrelateFunction(torch.autograd.Function):
@@ -889,6 +910,8 @@ class Decorrelate(torch.nn.Module):
             cov = torch.matmul(x.t(), x)
             with torch.no_grad():
                 self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
+            m = self.cov.max()
+            assert m == m  # NaN check: NaN compares unequal to itself.
 
         return ans  # ans == x.
 
@@ -987,6 +1010,14 @@ def _test_gauss_proj_drop():
     m1.eval()
     m2.eval()
 
+def _test_pseudo_normalize():
+    x = torch.randn(3, 4)
+    x.requires_grad = True
+    y = PseudoNormalizeFunction.apply(x)
+    l = y.sin().sum()
+    l.backward()
+    assert (x.grad * x).sum().abs() < 0.1  # x.grad should be orthogonal to x
+
 def _test_decorrelate():
     D = 384
     x = torch.randn(30000, D)
@@ -1014,6 +1045,7 @@ if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_pseudo_normalize()
    _test_decorrelate()
     _test_gauss_proj_drop()
     _test_activation_balancer_sign()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
index e1b1939fb..2babb94bc 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -94,7 +94,7 @@ class Conformer(EncoderInterface):
             aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
         )
 
-        self.decorrelate = Decorrelate(d_model, scale=0.1)
+        self.decorrelate = Decorrelate(d_model, scale=0.05)
 
     def forward(
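
Note (not part of the patch): the property that _test_pseudo_normalize checks, and that makes PseudoNormalizeFunction useful inside _update_cov_stats, is that the gradient it passes back is orthogonal to its input, so a loss on the decorrelation stats can change the "direction" of the fresh covariance term but not, to first order, its overall scale. Below is a minimal standalone sketch of that property, assuming only PyTorch and that the patched scaling.py is importable from the working directory; the helper name check_grad_orthogonal is illustrative:

    import torch

    from scaling import PseudoNormalizeFunction  # scaling.py as patched above

    def check_grad_orthogonal() -> None:
        x = torch.randn(5, 7, requires_grad=True)
        y = PseudoNormalizeFunction.apply(x)  # identity in the forward pass
        # Any scalar loss will do; backward() projects the x-direction
        # out of the incoming gradient.
        (y * torch.randn_like(y)).sum().backward()
        inner = (x.grad * x).sum()
        # <x.grad, x> should vanish up to floating-point rounding.
        assert inner.abs() < 1.0e-4 * x.grad.norm() * x.norm()

    if __name__ == "__main__":
        check_grad_orthogonal()
        print("OK: gradient is orthogonal to x")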
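
A second sketch (also illustrative, not part of the patch) of the exponential-moving-average rule that _update_cov_stats implements, cov <- beta * cov + (1 - beta) * x^T x; the patch only wraps the x^T x term in PseudoNormalizeFunction before blending. For unit-variance, uncorrelated inputs the diagonal of the stats settles near the number of frames per batch:

    import torch

    def ema_cov_update(cov: torch.Tensor, x: torch.Tensor, beta: float) -> torch.Tensor:
        # Decayed covariance stats over batches of shape (num_frames, num_channels).
        return cov * beta + torch.matmul(x.t(), x) * (1 - beta)

    cov = torch.zeros(8, 8)
    for _ in range(50):
        cov = ema_cov_update(cov, torch.randn(100, 8), beta=0.8)

    diag_mean = cov.diagonal().mean()  # ~100, i.e. num_frames
    off_diag = (cov - torch.diag_embed(cov.diagonal())).abs().mean()  # small vs. diag
    print(f"diag ~ {diag_mean:.1f}, |off-diag| ~ {off_diag:.2f}")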