diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 77a67782a..fc47660f7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -19,7 +19,7 @@ import copy
 import math
 import warnings
 from typing import List, Optional, Tuple
-
+import logging
 import torch
 from encoder_interface import EncoderInterface
 from scaling import (
@@ -29,6 +29,7 @@ from scaling import (
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,
+    Decorrelate,
 )
 
 from torch import Tensor, nn
@@ -93,6 +94,9 @@ class Conformer(EncoderInterface):
             aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
         )
 
+        self.decorrelate = Decorrelate(d_model, scale=0.05)
+
+
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -128,6 +132,8 @@ class Conformer(EncoderInterface):
             x, pos_emb, src_key_padding_mask=mask, warmup=warmup
         )  # (T, N, C)
 
+        x = self.decorrelate(x)
+
         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
 
         return x, lengths
@@ -201,6 +207,7 @@ class ConformerEncoderLayer(nn.Module):
 
         self.dropout = nn.Dropout(dropout)
 
+
     def forward(
         self,
         src: Tensor,
@@ -302,7 +309,6 @@ class ConformerEncoder(nn.Module):
         num_channels = encoder_layer.norm_final.num_channels
         self.combiner = RandomCombine(
             num_inputs=len(self.aux_layers),
-            num_channels=num_channels,
             final_weight=0.5,
             pure_prob=0.333,
             stddev=2.0,
@@ -371,7 +377,7 @@ class RelPositionalEncoding(torch.nn.Module):
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
-        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.dropout = torch.nn.Dropout(dropout_rate)
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
 
@@ -1032,6 +1038,7 @@ class Conv2dSubsampling(nn.Module):
             channel_dim=-1, min_positive=0.45, max_positive=0.55
         )
 
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.
 
@@ -1073,7 +1080,6 @@ class RandomCombine(nn.Module):
     def __init__(
         self,
         num_inputs: int,
-        num_channels: int,
        final_weight: float = 0.5,
        pure_prob: float = 0.5,
        stddev: float = 2.0,
@@ -1084,8 +1090,6 @@ class RandomCombine(nn.Module):
        The number of tensor inputs, which equals the number of
        layers' outputs that are fed into this module.  E.g. in an
        18-layer neural net if we output layers 16, 12, 18, num_inputs would be 3.
-      num_channels:
-        The number of channels on the input, e.g. 512.
      final_weight:
        The amount of weight or probability we assign to the final
        layer when randomly choosing layers or when choosing
@@ -1116,13 +1120,6 @@ class RandomCombine(nn.Module):
         assert 0 < final_weight < 1, final_weight
         assert num_inputs >= 1
 
-        self.linear = nn.ModuleList(
-            [
-                nn.Linear(num_channels, num_channels, bias=True)
-                for _ in range(num_inputs - 1)
-            ]
-        )
-
         self.num_inputs = num_inputs
         self.final_weight = final_weight
         self.pure_prob = pure_prob
@@ -1135,12 +1132,7 @@ class RandomCombine(nn.Module):
             .log()
             .item()
         )
-        self._reset_parameters()
 
-    def _reset_parameters(self):
-        for i in range(len(self.linear)):
-            nn.init.eye_(self.linear[i].weight)
-            nn.init.constant_(self.linear[i].bias, 0.0)
 
     def forward(self, inputs: List[Tensor]) -> Tensor:
         """Forward function.
@@ -1161,14 +1153,9 @@ class RandomCombine(nn.Module):
         num_channels = inputs[0].shape[-1]
         num_frames = inputs[0].numel() // num_channels
 
-        mod_inputs = []
-        for i in range(num_inputs - 1):
-            mod_inputs.append(self.linear[i](inputs[i]))
-        mod_inputs.append(inputs[num_inputs - 1])
-
         ndim = inputs[0].ndim
         # stacked_inputs: (num_frames, num_channels, num_inputs)
-        stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape(
+        stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
             (num_frames, num_channels, num_inputs)
         )
 
@@ -1282,7 +1269,6 @@ def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
     num_channels = 50
     m = RandomCombine(
         num_inputs=num_inputs,
-        num_channels=num_channels,
         final_weight=final_weight,
         pure_prob=pure_prob,
         stddev=stddev,
@@ -1329,5 +1315,8 @@ def _test_conformer_main():
 
 
 if __name__ == "__main__":
+    logging.getLogger().setLevel(logging.INFO)
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
     _test_conformer_main()
     _test_random_combine_main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 7ba71a94a..0556f69ba 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -18,7 +18,9 @@
 import collections
 from itertools import repeat
 from typing import Optional, Tuple
 
+import logging
+import random
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -335,6 +337,263 @@ class DoubleSwish(torch.nn.Module):
 
 
 
+class GaussProjDrop(torch.nn.Module):
+    """
+    This has an effect similar to torch.nn.Dropout, but does not privilege the on-axis directions.
+    The directions of dropout are fixed when the class is initialized, and are orthogonal.
+
+    dropout_rate: the dropout probability (actually will define the number of zeroed-out directions)
+    channel_dim: the axis corresponding to the channel, e.g. -1, 0, 1, 2.
+    """
+    def __init__(self,
+                 num_channels: int,
+                 dropout_rate: float = 0.1,
+                 channel_dim: int = -1):
+        super(GaussProjDrop, self).__init__()
+        self.dropout_rate = dropout_rate
+        # this formula for rand_scale was found empirically, trying to match the
+        # statistics of dropout in terms of cross-correlation with the input, see
+        # _test_gauss_proj_drop()
+        self.rand_scale = (dropout_rate / (1-dropout_rate)) ** 0.5  # * (num_channels ** -0.5)
+
+        self.channel_dim = channel_dim
+
+        rand_mat = torch.randn(num_channels, num_channels)
+        U, _, _ = rand_mat.svd()
+        self.register_buffer('U', U)  # a random orthogonal square matrix; will be a buffer.
+
+
+    def _randperm_like(self, x: Tensor):
+        """
+        Returns random permutations of the integers [0,1,..x.shape[-1]-1],
+        with the same shape as x.  All dimensions of x other than the last dimension
+        will be treated as batch dimensions.
+
+        Torch's randperm does not support a batch dimension, so we pseudo-randomly simulate it.
+
+        For now, requires x.shape[-1] to be either a power of 2 or 3 times a power of 2, as
+        we normally set channel dims.  This is required for some number theoretic stuff.
+        """
+        n = x.shape[-1]
+
+        assert n & (n-1) == 0 or (n//3 & (n//3 - 1)) == 0
+
+        b = x.numel() // n
+        randint = random.randint(0, 1000)
+        perm = torch.randperm(n, device=x.device)
+        # ensure all elements of batch_rand are coprime to n; this will ensure
+        # that multiplying the permutation by batch_rand and taking modulo
+        # n leaves us with permutations.
+        batch_rand = torch.arange(b, device=x.device) * (randint * 6) + 1
+        batch_rand = batch_rand.unsqueeze(-1)
+        ans = (perm * batch_rand) % n
+        ans = ans.reshape(x.shape)
+        return ans
+
+
+    def forward(self, x: Tensor) -> Tensor:
+        if not self.training:
+            return x
+        else:
+            x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
+            x_bypass = x  # will be used for "+ I"
+            perm = self._randperm_like(x)
+            x = torch.gather(x, -1, perm)
+            # self.U will act like a different matrix for every row of x, because of the random
+            # permutation.
+            x = torch.matmul(x, self.U)
+            x_next = torch.empty_like(x)
+            # scatter_ uses perm in the opposite way
+            # from gather, inverting it.
+            x_next.scatter_(-1, perm, x)
+            x = (x_next * self.rand_scale + x_bypass)
+            return x
+
+
+def _compute_correlation_loss(cov: Tensor,
+                              eps: float) -> Tensor:
+    """
+    Computes the correlation `loss`, which would be zero if the channel dimensions
+    are un-correlated, and equals num_channels if they are maximally correlated
+    (i.e., they all point in the same direction).
+    Args:
+      cov: Uncentered covariance matrix of shape (num_channels, num_channels),
+           does not have to be normalized by count.
+    """
+    inv_sqrt_diag = (cov.diag() + eps) ** -0.5
+    norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
+    num_channels = cov.shape[0]
+    loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels
+    return loss
+
+def _update_cov_stats(cov: Tensor,
+                      x: Tensor,
+                      beta: float) -> Tensor:
+    """
+    Updates covariance stats as a decaying sum, returning the result.
+      cov: Old covariance stats, to be added to and returned, of shape
+           (num_channels, num_channels)
+      x: Tensor of features/activations, of shape (num_frames, num_channels)
+      beta: The decay constant for the stats, e.g. 0.8.
+    """
+    new_cov = torch.matmul(x.t(), x)
+    return cov * beta + new_cov * (1-beta)
+
+
+class DecorrelateFunction(torch.autograd.Function):
+    """
+    Function object for an operation that does nothing in the forward pass but,
+    in the backward pass, adds derivatives that encourage the channel dimensions
+    to be un-correlated with each other.
+    This should not be used in a half-precision-enabled region; wrap it in
+    torch.cuda.amp.autocast(enabled=False).
+
+    Args:
+      x: The input tensor, which is also the function output.  It can have
+           arbitrary shape, but its dimension `channel_dim` (e.g. -1) will be
+           interpreted as the channel dimension.
+      old_cov: Covariance statistics from previous frames, accumulated as a decaying
+           sum over frames (not average), decaying by beta each time.
+      scale: The scale on the derivatives arising from this, expressed as
+           a fraction of the norm of the derivative at the output (applied per
+           frame).  We will further scale down the derivative if the normalized
+           loss is less than 1.
+      eps: Epsilon value to prevent division by zero when estimating diagonal
+           covariance.
+      beta: Decay constant that determines how we combine stats from x with
+           stats from cov; e.g. 0.8.
+      channel_dim: The dimension/axis of x corresponding to the channel, e.g. 0, 1, 2, -1.
+ """ + @staticmethod + def forward(ctx, x: Tensor, old_cov: Tensor, + scale: float, eps: float, beta: float, + channel_dim: int) -> Tensor: + ctx.save_for_backward(x.detach(), old_cov.detach()) + ctx.scale = scale + ctx.eps = eps + ctx.beta = beta + ctx.channel_dim = channel_dim + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: + x, old_cov = ctx.saved_tensors + + # Reshape x and x_grad to be (num_frames, num_channels) + x = x.transpose(-1, ctx.channel_dim) + x_grad = x_grad.transpose(-1, ctx.channel_dim) + num_channels = x.shape[-1] + full_shape = x.shape + x = x.reshape(-1, num_channels) + x_grad = x_grad.reshape(-1, num_channels) + x.requires_grad = True + + # Now, normalize the contributions of frames/pixels x to the covariance, + # to have magnitudes proportional to the norm of the gradient on that + # frame; the goal is to exclude "don't-care" frames such as padding frames from + # the computation. + x_grad_sqrt_norm = (x_grad ** 2).sum(dim=1) ** 0.25 + x_grad_sqrt_norm /= x_grad_sqrt_norm.mean() + x_grad_sqrt_norm_is_inf = (x_grad_sqrt_norm - x_grad_sqrt_norm != 0) + x_grad_sqrt_norm.masked_fill_(x_grad_sqrt_norm_is_inf, 1.0) + + with torch.enable_grad(): + # scale up frames with larger grads. + x_scaled = x * x_grad_sqrt_norm.unsqueeze(-1) + + cov = _update_cov_stats(old_cov, x_scaled, ctx.beta) + assert old_cov.dtype != torch.float16 + old_cov[:] = cov # update the stats outside! This is not really + # how backprop is supposed to work, but this input + # is not differentiable.. + loss = _compute_correlation_loss(cov, ctx.eps) + assert loss.dtype == torch.float32 + + if random.random() < 0.05: + logging.info(f"Decorrelate: loss = {loss}") + + loss.backward() + + decorr_x_grad = x.grad + + scale = ctx.scale * ((x_grad ** 2).mean() / ((decorr_x_grad ** 2).mean() + 1.0e-20)) ** 0.5 + decorr_x_grad = decorr_x_grad * scale + + x_grad = x_grad + decorr_x_grad + + # reshape back to original shape + x_grad = x_grad.reshape(full_shape) + x_grad = x_grad.transpose(-1, ctx.channel_dim) + + return x_grad, None, None, None, None, None + + + +class Decorrelate(torch.nn.Module): + """ + This module does nothing in the forward pass, but in the backward pass, modifies + the derivatives in such a way as to encourage the dimensions of its input to become + decorrelated. + + Args: + num_channels: The number of channels, e.g. 256. + apply_prob_decay: The probability with which we apply this each time, in + training mode, will decay as apply_prob_decay/(apply_prob_decay + step). + scale: This number determines the scale of the gradient contribution from + this module, relative to whatever the gradient was before; + this is applied per frame or pixel, by scaling gradients. + eps: An epsilon used to prevent division by zero. + beta: A value 0 < beta < 1 that controls decay of covariance stats + channel_dim: The dimension of the input corresponding to the channel, e.g. + -1, 0, 1, 2. + + """ + def __init__(self, + num_channels: int, + scale: float = 0.1, + apply_prob_decay: int = 1000, + eps: float = 1.0e-05, + beta: float = 0.95, + channel_dim: int = -1): + super(Decorrelate, self).__init__() + self.scale = scale + self.apply_prob_decay = apply_prob_decay + self.eps = eps + self.beta = beta + self.channel_dim = channel_dim + + self.register_buffer('cov', torch.zeros(num_channels, num_channels)) + # step_buf is a copy of step, included so it will be loaded/saved with + # the model. 
+        self.register_buffer('step_buf', torch.tensor(0.0))
+        self.step = 0
+
+
+    def load_state_dict(self, *args, **kwargs):
+        super(Decorrelate, self).load_state_dict(*args, **kwargs)
+        self.step = int(self.step_buf.item())
+
+
+    def forward(self, x: Tensor) -> Tensor:
+        if not self.training:
+            return x
+        else:
+            apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay)
+            self.step += 1
+            self.step_buf.fill_(float(self.step))
+            if random.random() > apply_prob:
+                return x
+            with torch.cuda.amp.autocast(enabled=False):
+                x = x.to(torch.float32)
+                # the function updates self.cov in its backward pass (it needs the gradient
+                # norm, for frame weighting).
+                ans = DecorrelateFunction.apply(x, self.cov,
+                                                self.scale, self.eps, self.beta,
+                                                self.channel_dim)  # == x.
+            return ans
+
+
 def _test_activation_balancer_sign():
     probs = torch.arange(0, 1, 0.01)
     N = 1000
@@ -344,7 +603,7 @@ def _test_activation_balancer_sign():
     m = ActivationBalancer(
         channel_dim=0,
         min_positive=0.05,
-        max_positive=0.95,
+        max_positive=0.98,
         max_factor=0.2,
         min_abs=0.0,
     )
@@ -408,7 +667,56 @@ def _test_double_swish_deriv():
     torch.autograd.gradcheck(m, x)
 
 
+def _test_gauss_proj_drop():
+    D = 384
+    x = torch.randn(30000, D)
+
+
+    for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
+        m1 = torch.nn.Dropout(dropout_rate)
+        m2 = GaussProjDrop(D, dropout_rate)
+        for mode in ['train', 'eval']:
+            y1 = m1(x)
+            y2 = m2(x)
+            xmag = (x*x).mean()
+            y1mag = (y1*y1).mean()
+            cross1 = (x*y1).mean()
+            y2mag = (y2*y2).mean()
+            cross2 = (x*y2).mean()
+            print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}")
+            m1.eval()
+            m2.eval()
+
+
+def _test_decorrelate():
+    D = 384
+    x = torch.randn(30000, D)
+    # give it a non-unit covariance.
+    m = torch.randn(D, D) * (D ** -0.5)
+    _, S, _ = m.svd()
+    print("M eigs = ", S[::10])
+    x = torch.matmul(x, m)
+
+
+    # check that class Decorrelate does not crash when running..
+    decorrelate = Decorrelate(D)
+    x.requires_grad = True
+    y = decorrelate(x)
+    y.sum().backward()
+
+    decorrelate2 = Decorrelate(D)
+    decorrelate2.load_state_dict(decorrelate.state_dict())
+    assert decorrelate2.step == decorrelate.step
+
+
+
 if __name__ == "__main__":
+    logging.getLogger().setLevel(logging.INFO)
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+    _test_decorrelate()
+    _test_gauss_proj_drop()
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()
     _test_basic_norm()
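
Usage note (not part of the patch itself): a minimal sketch of how the new Decorrelate module behaves, mirroring the call added to Conformer.forward above. The module is an identity in the forward pass; only in training mode, and only in the backward pass, does it add a gradient term that pushes the channels of its input toward being decorrelated. The shape and scale=0.05 follow the conformer.py change; everything else here is illustrative, and it assumes scaling.py from this directory is importable.

    import torch
    from scaling import Decorrelate  # assumes this directory is on PYTHONPATH

    decorrelate = Decorrelate(num_channels=512, scale=0.05)
    decorrelate.train()  # the module is a no-op in eval mode

    x = torch.randn(100, 8, 512, requires_grad=True)  # (T, N, C), like the encoder output
    y = decorrelate(x)   # forward pass: y has the same values as x
    assert torch.allclose(y, x)

    # The decorrelation term appears only in the backward pass, where it is added to
    # whatever gradient flows in from the loss and rescaled relative to that gradient's norm.
    y.sum().backward()
    assert x.grad.shape == x.shape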