diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 09eadc9e9..068c4f77e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -713,6 +713,104 @@ class GaussProjDrop(torch.nn.Module):
         x = (x_next * self.rand_scale + x_bypass)
         return x
 
+class DecorrelateFunction(torch.autograd.Function):
+    # Does nothing in the forward pass.  In the backward pass it modifies
+    # the gradients in such a way as to encourage the dims of x to become
+    # uncorrelated, as measured by the accumulated covariance stats.
+    @staticmethod
+    def forward(ctx, x: Tensor, old_cov: Tensor,
+                scale: float, eps: float, beta: float,
+                channel_dim: int) -> Tensor:
+        ctx.save_for_backward(x, old_cov)
+        ctx.scale = scale
+        ctx.eps = eps
+        ctx.beta = beta
+        ctx.channel_dim = channel_dim
+        return x
+
+    @staticmethod
+    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
+        x, old_cov = ctx.saved_tensors
+        with torch.enable_grad():
+            x = x.transpose(-1, ctx.channel_dim)
+            x_grad = x_grad.transpose(-1, ctx.channel_dim)
+            num_channels = x.shape[-1]
+            full_shape = x.shape
+            x = x.reshape(-1, num_channels)
+            x = x.detach()
+            x.requires_grad = True
+            x_grad = x_grad.reshape(-1, num_channels)
+
+            cov = old_cov * ctx.beta + torch.matmul(x.t(), x) * (1-ctx.beta)
+            inv_sqrt_diag = (cov.diag() + ctx.eps) ** -0.5
+            norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
+            print("Decorrelate: norm_cov = ", norm_cov)
+
+            loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels  # off-diagonal energy, per channel; >= 0
+            if random.random() < 0.01:
+                logging.info(f"Decorrelate: loss = {loss}")
+            loss.backward()
+            assert x.grad is not None
+            x_grad_new = x.grad
+
+            # Now, normalize the magnitudes of the rows of the new grad
+            # contribution, to have magnitudes equal to ctx.scale times
+            # `loss ** 0.5` times the magnitude of the original grad.
+            x_grad_new_scale = (x_grad_new ** 2).sum(dim=1)
+            x_grad_old_scale = (x_grad ** 2).sum(dim=1)
+            decorr_loss_scale = ctx.scale * loss.detach() ** 0.5
+            scale = decorr_loss_scale * (x_grad_old_scale / (x_grad_new_scale + 1.0e-10)) ** 0.5
+            x_grad_new = x_grad_new * scale.unsqueeze(-1)
+
+            x_grad = x_grad + x_grad_new
+            # reshape..
+            x_grad = x_grad.reshape(full_shape)
+            x_grad = x_grad.transpose(-1, ctx.channel_dim)
+
+        return x_grad, None, None, None, None, None
+
+
+
+class Decorrelate(torch.nn.Module):
+    """
+    This module does nothing in the forward pass; in the backward pass it modifies
+    the derivatives in such a way as to encourage the dimensions of its input to become
+    decorrelated.
+    """
+    def __init__(self,
+                 num_channels: int,
+                 scale: float = 0.1,
+                 eps: float = 1.0e-05,
+                 beta: float = 0.95,
+                 channel_dim: int = -1):
+        super(Decorrelate, self).__init__()
+        self.scale = scale
+        self.eps = eps
+        self.beta = beta
+        self.channel_dim = channel_dim
+
+        self.register_buffer('cov', torch.zeros(num_channels, num_channels))
+        self.step = 0
+
+
+    def forward(self, x: Tensor) -> Tensor:
+        if not self.training:
+            return x
+        else:
+            with torch.cuda.amp.autocast(enabled=False):
+                ans = DecorrelateFunction.apply(x, self.cov.clone(),
+                                                self.scale, self.eps, self.beta,
+                                                self.channel_dim)  # == x.
+
+                x = x.detach().transpose(self.channel_dim, -1)  # detach: stats update only, keep it out of the autograd graph
+                x = x.reshape(-1, x.shape[-1])
+                cov = torch.matmul(x.t(), x) / x.shape[0]
+                self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
+                self.step += 1
+
+            return ans  # ans == x.
+
+
 class JoinDropout(torch.nn.Module):
     """
     This module implements something like:
@@ -735,7 +833,7 @@ class JoinDropout(torch.nn.Module):
       beta: A value 0 < beta < 1 that controls decay of covariance stats
       channel_dim: The dimension of the input corresponding to the channel, e.g.
           -1, 0, 1, 2.
-    """
+    """
    def __init__(self,
                 num_channels: int,
                 apply_prob: float = 0.75,
@@ -955,9 +1053,19 @@ def _test_join_dropout():
 
     x = torch.randn(30000, D)
     # give it a non-unit covariance.
-    m = torch.randn(D, D)
+    m = torch.randn(D, D) * (D ** -0.5)
+    _, S, _ = m.svd()
+    print("M eigs = ", S[::10])
     x = torch.matmul(x, m)
+
+    if True:
+        # check that class Decorrelate does not crash when running..
+        decorrelate = Decorrelate(D)
+        x.requires_grad = True
+        y = decorrelate(x)
+        y.sum().backward()
+
 
     for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
         m1 = torch.nn.Dropout(dropout_rate)
         m2 = JoinDropout(D, apply_prob=1.0, dropout_rate=dropout_rate)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
index 7440fe05a..03f7fe6f5 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import List, Optional, Tuple
-
+import logging
 import torch
 from encoder_interface import EncoderInterface
 from scaling import (
@@ -30,7 +30,7 @@ from scaling import (
     ScaledConv2d,
     ScaledLinear,
     JoinDropout,
-
+    Decorrelate,
 )
 from torch import Tensor, nn
 
@@ -1032,11 +1032,13 @@ class Conv2dSubsampling(nn.Module):
         # itself has learned scale, so the extra degree of freedom is not
         # needed.
         self.out_norm = BasicNorm(out_channels, learn_eps=False)
+        self.decorrelate = Decorrelate(out_channels)
         # constrain median of output to be close to zero.
         self.out_balancer = ActivationBalancer(
             channel_dim=-1, min_positive=0.45, max_positive=0.55
         )
+
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
 
@@ -1055,6 +1057,7 @@ class Conv2dSubsampling(nn.Module):
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
         # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
         x = self.out_norm(x)
+        x = self.decorrelate(x)
         x = self.out_balancer(x)
         return x
 
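
For readers skimming the patch, here is a small self-contained sketch (not part of the diff) of the quantity that DecorrelateFunction differentiates in its backward pass: the off-diagonal energy of the normalized covariance of the features. The feature dimension D and the number of frames are arbitrary values chosen for illustration.

# Illustration only, not part of the patch: the decorrelation penalty
# written out directly, outside of any autograd.Function machinery.
import torch

torch.manual_seed(0)
D = 8
x = torch.randn(1000, D)
x = torch.matmul(x, torch.randn(D, D) * (D ** -0.5))   # give the dims some correlation

eps = 1.0e-05
cov = torch.matmul(x.t(), x)                           # covariance stats
inv_sqrt_diag = (cov.diag() + eps) ** -0.5
norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))   # diagonal ~= 1, off-diagonal = correlation

# Off-diagonal energy per channel: zero when the dimensions are uncorrelated.
loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / D
print("decorrelation loss =", loss.item())

In the patch this loss is never added to the training objective; its gradient with respect to x is rescaled and mixed into the incoming gradient, so the forward computation is left untouched.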
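A minimal usage sketch of the new module, mirroring the smoke test in _test_join_dropout and the way Conv2dSubsampling wires it in after out_norm. It assumes the script is run from egs/librispeech/ASR/pruned_transducer_stateless2 with this patch applied, so that `scaling.Decorrelate` is importable; the tensor shapes are arbitrary.

# Illustration only, not part of the patch.
import torch
from scaling import Decorrelate

decorrelate = Decorrelate(256)      # as in Conv2dSubsampling: Decorrelate(out_channels)

x = torch.randn(4, 100, 256, requires_grad=True)   # (batch, time, channels)
y = decorrelate(x)                  # forward pass returns x unchanged
y.sum().backward()                  # backward pass mixes the decorrelating term into x.grad
assert x.grad is not None

decorrelate.eval()                  # outside training the module is a pure no-op
assert torch.equal(decorrelate(x), x)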