From 0fd2cb141f2b3f8d13ffc76ff0b68084a9a94925 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 9 Jun 2022 22:54:01 +0800
Subject: [PATCH] Code cleanup and refactoring

---
 .../pruned_transducer_stateless2/scaling.py   | 292 ++++++------------
 .../pruned_transducer_stateless5/conformer.py |   1 -
 2 files changed, 86 insertions(+), 207 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index afdea6691..c0a0cb2c2 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -713,15 +713,65 @@ class GaussProjDrop(torch.nn.Module):
         x = (x_next * self.rand_scale + x_bypass)
         return x
 
+
+def _compute_correlation_loss(cov: Tensor,
+                              eps: float) -> Tensor:
+    """
+    Computes the correlation `loss`, which would be zero if the channel dimensions
+    are un-correlated, and approaches num_channels if they are maximally correlated
+    (i.e., they all point in the same direction).
+    Args:
+       cov: Uncentered covariance matrix of shape (num_channels, num_channels);
+            does not have to be normalized by count.
+    """
+    inv_sqrt_diag = (cov.diag() + eps) ** -0.5
+    norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
+    num_channels = cov.shape[0]
+    loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels
+    return loss
+
+
+def _update_cov_stats(cov: Tensor,
+                      x: Tensor,
+                      beta: float) -> Tensor:
+    """
+    Updates covariance stats as a decaying sum, returning the result.
+      cov: Old covariance stats, to be added to and returned, of shape
+           (num_channels, num_channels)
+      x: Tensor of features/activations, of shape (num_frames, num_channels)
+      beta: The decay constant for the stats, e.g. 0.8.
+    """
+    return cov * beta + torch.matmul(x.t(), x) * (1 - beta)
+
+
 class DecorrelateFunction(torch.autograd.Function):
-    # does nothing in forward pass.  In backward pass it modifies
-    # the gradients in such a way as to encourage the dims of x
-    # to become uncorrelated, taken over all the stats.
+    """
+    Function object for a function that does nothing in the forward pass,
+    but, in the backward pass, adds derivatives that encourage the channel dimensions
+    to be un-correlated with each other.
+    This should not be used in a half-precision-enabled region; wrap it in
+    `with torch.cuda.amp.autocast(enabled=False)`.
+
+    Args:
+          x: The input tensor, which is also the function output.  It can have
+             arbitrary shape, but its dimension `channel_dim` (e.g. -1) will be
+             interpreted as the channel dimension.
+    old_cov: Covariance statistics from previous frames, accumulated as a decaying
+             sum over frames (not an average), decaying by beta each time.
+      scale: The scale on the derivatives arising from this, expressed as
+             a fraction of the norm of the derivative at the output (applied per
+             frame).  We further scale down the derivative if the normalized
+             loss is less than 1.
+        eps: Epsilon value to prevent division by zero when estimating the diagonal
+             of the covariance.
+       beta: Decay constant that determines how we combine stats from x with
+             stats from old_cov; e.g. 0.8.
+ channel_dim: The dimension/axis of x corresponding to the channel, e.g. 0, 1, 2, -1.
+ """ @staticmethod def forward(ctx, x: Tensor, old_cov: Tensor, scale: float, eps: float, beta: float, channel_dim: int) -> Tensor: - ctx.save_for_backward(x, old_cov) + ctx.save_for_backward(x.detach(), old_cov.detach()) ctx.scale = scale ctx.eps = eps ctx.beta = beta @@ -731,45 +781,48 @@ class DecorrelateFunction(torch.autograd.Function): @staticmethod def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: x, old_cov = ctx.saved_tensors + + + # Reshape x and x_grad to be (num_frames, num_channels) + x = x.transpose(-1, ctx.channel_dim) + x_grad = x_grad.transpose(-1, ctx.channel_dim) + num_channels = x.shape[-1] + full_shape = x.shape + x = x.reshape(-1, num_channels) + x_grad = x_grad.reshape(-1, num_channels) + x.requires_grad = True + with torch.enable_grad(): - x = x.transpose(-1, ctx.channel_dim) - x_grad = x_grad.transpose(-1, ctx.channel_dim) - num_channels = x.shape[-1] - full_shape = x.shape - x = x.reshape(-1, num_channels) - x = x.detach() - old_cov = old_cov.detach() - x.requires_grad = True - x_grad = x_grad.reshape(-1, num_channels) + cov = _update_cov_stats(old_cov, x, ctx.beta) + loss = _compute_correlation_loss(cov, ctx.eps) - cov = old_cov * ctx.beta + torch.matmul(x.t(), x) * (1-ctx.beta) - inv_sqrt_diag = (cov.diag() + ctx.eps) ** -0.5 - norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1)) - - loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels if random.random() < 0.01: logging.info(f"Decorrelate: loss = {loss}") loss.backward() - x_grad_new = x.grad - assert x.grad is not None - # Now, normalize the magnitudes of the rows of the new grad - # contribution, to have magnitudes equals to ctx.scale times - # `loss ** 0.5` times the magnitude of the original grad. - x_grad_new_scale = (x_grad_new ** 2).sum(dim=1) - x_grad_old_scale = (x_grad ** 2).sum(dim=1) + decorr_x_grad = x.grad + assert x.grad is not None - decorr_loss_scale = ctx.scale * loss.detach().clamp(min=0.0, max=1.0) + # Now, normalize the magnitudes of the rows of the new grad + # contribution, to have magnitudes equals to ctx.scale times + # `loss ** 0.5` times the magnitude of the original grad. + decorr_x_grad_sqnorm = (decorr_x_grad ** 2).sum(dim=1) + x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1) - scale = decorr_loss_scale * (x_grad_old_scale / (x_grad_new_scale + 1.0e-10)) ** 0.5 - x_grad_new = x_grad_new * scale.unsqueeze(-1) + # loss.detach().clamp(min=0.0, max=1.0) is a factor that means once + # the loss starts getting quite small (less than 1), we start using + # smaller derivatives. + decorr_loss_scale = ctx.scale * loss.detach().clamp(min=0.0, max=1.0) + scale = decorr_loss_scale * (x_grad_old_sqnorm / (decorr_x_grad_sqnorm + 1.0e-10)) ** 0.5 + decorr_x_grad = decorr_x_grad * scale.unsqueeze(-1) - x_grad = x_grad + x_grad_new - # reshape.. - x_grad = x_grad.reshape(full_shape) - x_grad = x_grad.transpose(-1, ctx.channel_dim) + x_grad = x_grad + decorr_x_grad - return x_grad, None, None, None, None, None + # reshape back to original shape + x_grad = x_grad.reshape(full_shape) + x_grad = x_grad.transpose(-1, ctx.channel_dim) + + return x_grad, None, None, None, None, None @@ -813,7 +866,6 @@ class Decorrelate(torch.nn.Module): self.step = 0 - def load_state_dict(self, *args, **kwargs): super(Decorrelate, self).load_state_dict(*args, **kwargs) self.step = int(self.step_buf.item()) @@ -842,149 +894,6 @@ class Decorrelate(torch.nn.Module): return ans # ans == x. 
-class JoinDropout(torch.nn.Module):
-    """
-    This module implements something like:
-          y = bypass + dropout(x)
-    but does it in such a way as to encourage x to vary in directions that will tend
-    to make the dimensions of y as decorrelated as possible.  We do this
-    by putting lots of dropout in directions in the space in which we
-    don't want x to vary (because it will tend to increase correlations between
-    dimensions in the output y).
-
-
-    Args:
-      num_channels: The number of channels, e.g. 256.
-      apply_prob_decay: The probability with which we apply this each time, in
-             training mode, will decay as apply_prob_decay/(apply_prob_decay + step).
-      dropout_rate: This number determines the average dropout probability
-             (it will actually vary across dimensions).
-      eps: An epsilon used to prevent division by zero.
-      beta: A value 0 < beta < 1 that controls decay of covariance stats
-      channel_dim: The dimension of the input corresponding to the channel, e.g.
-             -1, 0, 1, 2.
-    """
-    def __init__(self,
-                 num_channels: int,
-                 apply_prob: float = 0.75,
-                 dropout_rate: float = 0.1,
-                 eps: float = 1.0e-04,
-                 beta: float = 0.95,
-                 channel_dim: int = -1):
-        super(JoinDropout, self).__init__()
-        self.apply_prob = apply_prob
-        self.dropout_rate = dropout_rate
-        self.channel_dim = channel_dim
-        self.eps = eps
-        self.beta = beta
-
-        self.register_buffer('T1', torch.eye(num_channels))
-        self.register_buffer('dropout_probs', torch.zeros(num_channels))
-        self.register_buffer('scales', torch.ones(num_channels))
-        self.register_buffer('T2', torch.eye(num_channels))
-        self.register_buffer('cov', torch.zeros(num_channels, num_channels))
-        self.step = 0
-
-
-    def _update_covar_stats(self, y: Tensor) -> None:
-        """
-        Args:
-          y: Tensor of shape (*, num_channels), of output.
-        Updates covariance stats self.cov
-        """
-        y = y.detach()
-        y = y.reshape(-1, y.shape[-1])
-        y = y * (y.shape[0] ** -0.5)  # avoid overflow in half precision
-        cov = torch.matmul(y.t(), y)
-        self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
-
-    def _update_transforms(self):
-        norm_cov, inv_sqrt_diag = self._normalize_covar(self.cov)
-
-        U, S, _ = norm_cov.svd()  # because diag of norm_cov is 1.0, S.mean() == 1.0
-
-        if random.random() < 0.1:
-            logging.info(f"JoinDropout: max,min eig of normalized cov is: {S.max().item():.2e},{S.min().item():.2e}")
-
-        dropout_probs = (S.sqrt() - 0.99).clamp(min=0)
-        dropout_probs = dropout_probs * (self.dropout_rate / dropout_probs.mean())
-        dropout_probs = dropout_probs.clamp(max=0.5)
-        self.dropout_probs[:] = dropout_probs
-        self.scales[:] = 1.0 / (1 - dropout_probs)
-
-
-        # row indexes of U correspond to channels, column indexes correspond to
-        # singular values: cov = U * diag(S) * U.t() where * is matmul.
-
-
-        # Transform T1, which we'll incorporate as torch.matmul(x, self.T1), is:
-        #  (i) multiply by inv_sqrt_diag which makes the covariance have
-        #      a unit diagonal.
-        #  (ii) multiply by U, which diagonalizes norm_cov (uncorrelated channels)
-        self.T1[:] = (inv_sqrt_diag.unsqueeze(-1) * U)
-
-        # Transform T2, which we'll incorporate as torch.matmul(x, self.T2), is:
-        #  (i) multiply by U, which un-diagonalizes norm_cov
-        #  (ii) divide by inv_sqrt_diag which makes the covariance have its original
-        #       diagonal values.
-        self.T2[:] = (U.t() / inv_sqrt_diag)
-
-
-        if random.random() < 0.01:
-            d = torch.matmul(self.T1, self.T2) - torch.eye(self.T1.shape[0],
-                                                           device=self.T1.device,
-                                                           dtype=self.T1.dtype)
-            assert torch.all(d.abs() < 0.01)
-
-
-
-    def _normalize_covar(self, cov: Tensor) -> Tensor:
-        """
-        Normlizes a covariance matrix so that its diagonal is 1, by multiplying by
-        its diagonal**-0.5 on both sides.
-        Args:
-           cov: matrix to normalize
-        Returns normalized_cov, inv_sqrt_diag
-        """
-        diag = cov.diag()
-        inv_sqrt_diag = (diag + self.eps) ** -0.5
-        cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
-        return cov, inv_sqrt_diag
-
-
-
-    def forward(self, bypass: Tensor, x: Tensor) -> Tensor:
-        apply_prob = self.apply_prob
-        if not self.training or random.random() > apply_prob:
-            return bypass + x
-        else:
-            x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
-            bypass = bypass.transpose(self.channel_dim, -1)
-
-            x = torch.matmul(x, self.T1.clone())
-
-            mask = (torch.rand_like(x) > self.dropout_probs)
-            x = (x * mask) * self.scales.clone()
-            x = torch.matmul(x, self.T2.clone())
-
-            y = bypass + x
-            self.step += 1
-            with torch.cuda.amp.autocast(enabled=False):
-                if self.step % 4 == 0 or __name__ == "__main__":
-                    self._update_covar_stats(y)
-                if self.step % 40 == 0 or __name__ == "__main__":
-                    # note: important that 40 is a multiple of 4
-                    self._update_transforms()
-
-            y = y.transpose(self.channel_dim, -1)
-            return y
-
-
-
-
-
-
 def _test_activation_balancer_sign():
     probs = torch.arange(0, 1, 0.01)
@@ -1101,41 +1010,12 @@ def _test_decorrelate():
 
 
-def _test_join_dropout():
-    D = 384
-    x = torch.randn(30000, D)
-
-    # give it a non-unit covariance.
-    m = torch.randn(D, D) * (D ** -0.5)
-    _, S, _ = m.svd()
-    print("M eigs = ", S[::10])
-    x = torch.matmul(x, m)
-
-
-    for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
-        m1 = torch.nn.Dropout(dropout_rate)
-        m2 = JoinDropout(D, apply_prob=1.0, dropout_rate=dropout_rate)
-        bypass = torch.zeros_like(x)
-        for mode in ['train', 'eval']:
-            y1 = m1(x)
-            for _ in range(2):
-                y2 = m2(bypass, x)
-            xmag = (x*x).mean()
-            y1mag = (y1*y1).mean()
-            cross1 = (x*y1).mean()
-            y2mag = (y2*y2).mean()
-            cross2 = (x*y2).mean()
-            print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}, ratio1={y1mag/cross1}, ratio2={y2mag/cross2}")
-        m1.eval()
-        m2.eval()
-
-
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
     _test_decorrelate()
-    _test_join_dropout()
     _test_gauss_proj_drop()
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
index 0a85841fe..427632cfe 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py
@@ -29,7 +29,6 @@ from scaling import (
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,
-    JoinDropout,
     Decorrelate,
 )
 from torch import Tensor, nn
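
Standalone sanity-check sketch (not part of the patch): the helper below mirrors the
_compute_correlation_loss function added to scaling.py above, and shows that the loss is
close to zero for uncorrelated channels and about num_channels - 1 (so roughly num_channels)
when every channel carries the same signal. The shapes, eps value, and function name here
are illustrative assumptions, not part of the icefall code.

import torch

def compute_correlation_loss(cov: torch.Tensor, eps: float = 1.0e-04) -> torch.Tensor:
    # Mirrors the patch's _compute_correlation_loss: normalize the uncentered
    # covariance so its diagonal is ~1, then sum the squared off-diagonal
    # entries and divide by the number of channels.
    inv_sqrt_diag = (cov.diag() + eps) ** -0.5
    norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
    num_channels = cov.shape[0]
    return ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels

num_frames, num_channels = 10000, 8
x = torch.randn(num_frames, num_channels)

# Roughly independent channels: loss is close to 0.
print(compute_correlation_loss(torch.matmul(x.t(), x)))

# Every channel is a copy of one signal: loss is about num_channels - 1 (here ~7).
y = x[:, :1].repeat(1, num_channels)
print(compute_correlation_loss(torch.matmul(y.t(), y)))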