diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 0556f69ba..71b4db7b3 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -410,188 +410,6 @@ class GaussProjDrop(torch.nn.Module):
         return x


-def _compute_correlation_loss(cov: Tensor,
-                              eps: float) -> Tensor:
-    """
-    Computes the correlation `loss`, which would be zero if the channel dimensions
-    are un-correlated, and equals num_channels if they are maximally correlated
-    (i.e., they all point in the same direction)
-    Args:
-       cov: Uncentered covariance matrix of shape (num_channels, num_channels),
-            does not have to be normalized by count.
-    """
-    inv_sqrt_diag = (cov.diag() + eps) ** -0.5
-    norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
-    num_channels = cov.shape[0]
-    loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels
-    return loss
-
-def _update_cov_stats(cov: Tensor,
-                      x: Tensor,
-                      beta: float) -> Tensor:
-    """
-    Updates covariance stats as a decaying sum, returning the result.
-    cov: Old covariance stats, to be added to and returned, of shape
-         (num_channels, num_channels)
-    x: Tensor of features/activations, of shape (num_frames, num_channels)
-    beta: The decay constant for the stats, e.g. 0.8.
-    """
-    new_cov = torch.matmul(x.t(), x)
-    return cov * beta + new_cov * (1-beta)
-
-
-class DecorrelateFunction(torch.autograd.Function):
-    """
-    Function object for a function that does nothing in the forward pass;
-    but, in the backward pass, adds derivatives that encourage the channel dimensions
-    to be un-correlated with each other.
-    This should not be used in a half-precision-enabled area, use
-    with torch.cuda.amp.autocast(enabled=False).
-
-    Args:
-        x: The input tensor, which is also the function output.  It can have
-            arbitrary shape, but its dimension `channel_dim` (e.g. -1) will be
-            interpreted as the channel dimension.
-        old_cov: Covariance statistics from previous frames, accumulated as a decaying
-            sum over frames (not average), decaying by beta each time.
-        scale: The scale on the derivatives arising from this, expressed as
-            a fraction of the norm of the derivative at the output (applied per
-            frame).  We will further scale down the derivative if the normalized
-            loss is less than 1.
-        eps: Epsilon value to prevent division by zero when estimating diagonal
-            covariance.
-        beta: Decay constant that determines how we combine stats from x with
-            stats from cov; e.g. 0.8.
-        channel_dim: The dimension/axis of x corresponding to the channel, e.g. 0, 1, 2, -1.
-    """
-    @staticmethod
-    def forward(ctx, x: Tensor, old_cov: Tensor,
-                scale: float, eps: float, beta: float,
-                channel_dim: int) -> Tensor:
-        ctx.save_for_backward(x.detach(), old_cov.detach())
-        ctx.scale = scale
-        ctx.eps = eps
-        ctx.beta = beta
-        ctx.channel_dim = channel_dim
-        return x
-
-    @staticmethod
-    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
-        x, old_cov = ctx.saved_tensors
-
-        # Reshape x and x_grad to be (num_frames, num_channels)
-        x = x.transpose(-1, ctx.channel_dim)
-        x_grad = x_grad.transpose(-1, ctx.channel_dim)
-        num_channels = x.shape[-1]
-        full_shape = x.shape
-        x = x.reshape(-1, num_channels)
-        x_grad = x_grad.reshape(-1, num_channels)
-        x.requires_grad = True
-
-        # Now, normalize the contributions of frames/pixels x to the covariance,
-        # to have magnitudes proportional to the norm of the gradient on that
-        # frame; the goal is to exclude "don't-care" frames such as padding frames from
-        # the computation.
-        x_grad_sqrt_norm = (x_grad ** 2).sum(dim=1) ** 0.25
-        x_grad_sqrt_norm /= x_grad_sqrt_norm.mean()
-        x_grad_sqrt_norm_is_inf = (x_grad_sqrt_norm - x_grad_sqrt_norm != 0)
-        x_grad_sqrt_norm.masked_fill_(x_grad_sqrt_norm_is_inf, 1.0)
-
-        with torch.enable_grad():
-            # scale up frames with larger grads.
-            x_scaled = x * x_grad_sqrt_norm.unsqueeze(-1)
-
-            cov = _update_cov_stats(old_cov, x_scaled, ctx.beta)
-            assert old_cov.dtype != torch.float16
-            old_cov[:] = cov  # update the stats outside!  This is not really
-                              # how backprop is supposed to work, but this input
-                              # is not differentiable..
-            loss = _compute_correlation_loss(cov, ctx.eps)
-            assert loss.dtype == torch.float32
-
-            if random.random() < 0.05:
-                logging.info(f"Decorrelate: loss = {loss}")
-
-            loss.backward()
-
-        decorr_x_grad = x.grad
-
-        scale = ctx.scale * ((x_grad ** 2).mean() / ((decorr_x_grad ** 2).mean() + 1.0e-20)) ** 0.5
-        decorr_x_grad = decorr_x_grad * scale
-
-        x_grad = x_grad + decorr_x_grad
-
-        # reshape back to original shape
-        x_grad = x_grad.reshape(full_shape)
-        x_grad = x_grad.transpose(-1, ctx.channel_dim)
-
-        return x_grad, None, None, None, None, None
-
-
-
-class Decorrelate(torch.nn.Module):
-    """
-    This module does nothing in the forward pass, but in the backward pass, modifies
-    the derivatives in such a way as to encourage the dimensions of its input to become
-    decorrelated.
-
-    Args:
-      num_channels: The number of channels, e.g. 256.
-      apply_prob_decay: The probability with which we apply this each time, in
-            training mode, will decay as apply_prob_decay/(apply_prob_decay + step).
-      scale: This number determines the scale of the gradient contribution from
-            this module, relative to whatever the gradient was before;
-            this is applied per frame or pixel, by scaling gradients.
-      eps: An epsilon used to prevent division by zero.
-      beta: A value 0 < beta < 1 that controls decay of covariance stats
-      channel_dim: The dimension of the input corresponding to the channel, e.g.
-            -1, 0, 1, 2.
-
-    """
-    def __init__(self,
-                 num_channels: int,
-                 scale: float = 0.1,
-                 apply_prob_decay: int = 1000,
-                 eps: float = 1.0e-05,
-                 beta: float = 0.95,
-                 channel_dim: int = -1):
-        super(Decorrelate, self).__init__()
-        self.scale = scale
-        self.apply_prob_decay = apply_prob_decay
-        self.eps = eps
-        self.beta = beta
-        self.channel_dim = channel_dim
-
-        self.register_buffer('cov', torch.zeros(num_channels, num_channels))
-        # step_buf is a copy of step, included so it will be loaded/saved with
-        # the model.
-        self.register_buffer('step_buf', torch.tensor(0.0))
-        self.step = 0
-
-
-    def load_state_dict(self, *args, **kwargs):
-        super(Decorrelate, self).load_state_dict(*args, **kwargs)
-        self.step = int(self.step_buf.item())
-
-
-    def forward(self, x: Tensor) -> Tensor:
-        if not self.training:
-            return x
-        else:
-            apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay)
-            self.step += 1
-            self.step_buf.fill_(float(self.step))
-            if random.random() > apply_prob:
-                return x
-            with torch.cuda.amp.autocast(enabled=False):
-                x = x.to(torch.float32)
-                # the function updates self.cov in its backward pass (it needs the gradient
-                # norm, for frame weighting).
-                ans = DecorrelateFunction.apply(x, self.cov,
-                                                self.scale, self.eps, self.beta,
-                                                self.channel_dim)  # == x.
-            return ans
-


 def _test_activation_balancer_sign():
@@ -688,34 +506,11 @@ def _test_gauss_proj_drop():
     m2.eval()


-def _test_decorrelate():
-    D = 384
-    x = torch.randn(30000, D)
-    # give it a non-unit covariance.
-    m = torch.randn(D, D) * (D ** -0.5)
-    _, S, _ = m.svd()
-    print("M eigs = ", S[::10])
-    x = torch.matmul(x, m)
-
-
-    # check that class Decorrelate does not crash when running..
-    decorrelate = Decorrelate(D)
-    x.requires_grad = True
-    y = decorrelate(x)
-    y.sum().backward()
-
-    decorrelate2 = Decorrelate(D)
-    decorrelate2.load_state_dict(decorrelate.state_dict())
-    assert decorrelate2.step == decorrelate.step
-
-
-
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
-    _test_decorrelate()
     _test_gauss_proj_drop()
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()
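
For reference while reviewing the deletion, here is a minimal, self-contained sketch (not part of the patch) of the correlation penalty that the removed _compute_correlation_loss implemented; the helper name correlation_loss and the toy tensors are illustrative only.

import torch

def correlation_loss(cov: torch.Tensor, eps: float = 1.0e-05) -> torch.Tensor:
    # Normalize the uncentered covariance so its diagonal is ~1, then penalize
    # the squared off-diagonal entries, divided by the number of channels.
    inv_sqrt_diag = (cov.diag() + eps) ** -0.5
    norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
    num_channels = cov.shape[0]
    return ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels

if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(1000, 8)                # roughly decorrelated channels
    print(correlation_loss(x.t() @ x))      # close to 0
    y = x[:, :1].repeat(1, 8)               # identical channels
    print(correlation_loss(y.t() @ y))      # close to num_channels - 1 (here 7)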