Add Decorrelate module that adds a decorrelating term to the gradients in the backward pass

This commit is contained in:
Daniel Povey 2022-06-08 19:44:58 +08:00
parent 9fb8645168
commit 46ca1cd4c4
2 changed files with 115 additions and 4 deletions


@@ -713,6 +713,104 @@ class GaussProjDrop(torch.nn.Module):
            x = (x_next * self.rand_scale + x_bypass)
            return x
class DecorrelateFunction(torch.autograd.Function):
    # Does nothing in the forward pass.  In the backward pass it modifies
    # the gradients in such a way as to encourage the dims of x to become
    # uncorrelated, taken over all the stats.
    @staticmethod
    def forward(ctx, x: Tensor, old_cov: Tensor,
                scale: float, eps: float, beta: float,
                channel_dim: int) -> Tensor:
        ctx.save_for_backward(x, old_cov)
        ctx.scale = scale
        ctx.eps = eps
        ctx.beta = beta
        ctx.channel_dim = channel_dim
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
        x, old_cov = ctx.saved_tensors
        with torch.enable_grad():
            # Put the channel dim last and flatten all other dims, so the stats
            # are taken over everything except the channel dim.
            x = x.transpose(-1, ctx.channel_dim)
            x_grad = x_grad.transpose(-1, ctx.channel_dim)
            num_channels = x.shape[-1]
            full_shape = x.shape
            x = x.reshape(-1, num_channels)
            x = x.detach()
            x.requires_grad = True
            x_grad = x_grad.reshape(-1, num_channels)

            # Blend the stored covariance stats with this batch's stats
            # (normalized by the number of frames so they are on the same scale
            # as old_cov, which is accumulated that way in Decorrelate.forward()).
            cov = old_cov * ctx.beta + (torch.matmul(x.t(), x) / x.shape[0]) * (1 - ctx.beta)
            inv_sqrt_diag = (cov.diag() + ctx.eps) ** -0.5
            norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
            # Sum of squared off-diagonal elements of the normalized covariance;
            # this is zero when the dims of x are fully decorrelated.
            loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels
            if random.random() < 0.01:
                logging.info(f"Decorrelate: loss = {loss}")
            loss.backward()
            assert x.grad is not None
            x_grad_new = x.grad

            # Now, normalize the magnitudes of the rows of the new grad
            # contribution, to have magnitudes equal to ctx.scale times
            # `loss ** 0.5` times the magnitude of the original grad.
            x_grad_new_scale = (x_grad_new ** 2).sum(dim=1)
            x_grad_old_scale = (x_grad ** 2).sum(dim=1)
            decorr_loss_scale = ctx.scale * loss.detach() ** 0.5
            scale = decorr_loss_scale * (x_grad_old_scale / (x_grad_new_scale + 1.0e-10)) ** 0.5
            x_grad_new = x_grad_new * scale.unsqueeze(-1)
            x_grad = x_grad + x_grad_new

            # Restore the original shape and dim order.
            x_grad = x_grad.reshape(full_shape)
            x_grad = x_grad.transpose(-1, ctx.channel_dim)

        return x_grad, None, None, None, None, None
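# Illustrative sketch, not part of this commit: a toy run of the normalized
# covariance and off-diagonal loss computed in DecorrelateFunction.backward()
# above.  The helper name `_demo_decorrelation_loss` and the constants used
# here are made up for illustration.
def _demo_decorrelation_loss():
    num_channels = 8
    x = torch.randn(1000, num_channels)
    cov = torch.matmul(x.t(), x) / x.shape[0]
    inv_sqrt_diag = (cov.diag() + 1.0e-05) ** -0.5
    norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
    # Off-diagonal energy of the normalized covariance: near zero when the
    # dims are independent.
    loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels
    print("decorrelation loss, independent dims:", loss.item())

    # Mixing the channels with a random matrix correlates them and makes the
    # loss much larger.
    x = torch.matmul(x, torch.randn(num_channels, num_channels))
    cov = torch.matmul(x.t(), x) / x.shape[0]
    inv_sqrt_diag = (cov.diag() + 1.0e-05) ** -0.5
    norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
    loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels
    print("decorrelation loss, mixed dims:", loss.item())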
class Decorrelate(torch.nn.Module):
    """
    This module does nothing in the forward pass, but in the backward pass, modifies
    the derivatives in such a way as to encourage the dimensions of its input to become
    decorrelated.

    Args:
      num_channels: The number of channels, i.e. the size of the dimension given
           by channel_dim, over which covariance stats are accumulated.
      scale: The scale of the decorrelating gradient term relative to the
           original gradient, see DecorrelateFunction.backward().
      eps: An epsilon added to the diagonal of the covariance before it is
           normalized, to avoid division by zero.
      beta: A value 0 < beta < 1 that controls decay of covariance stats
      channel_dim: The dimension of the input corresponding to the channel, e.g.
           -1, 0, 1, 2.
    """
    def __init__(self,
                 num_channels: int,
                 scale: float = 0.1,
                 eps: float = 1.0e-05,
                 beta: float = 0.95,
                 channel_dim: int = -1):
        super(Decorrelate, self).__init__()
        self.scale = scale
        self.eps = eps
        self.beta = beta
        self.channel_dim = channel_dim
        self.register_buffer('cov', torch.zeros(num_channels, num_channels))
        self.step = 0

    def forward(self, x: Tensor) -> Tensor:
        if not self.training:
            return x
        else:
            with torch.cuda.amp.autocast(enabled=False):
                ans = DecorrelateFunction.apply(x, self.cov.clone(),
                                                self.scale, self.eps, self.beta,
                                                self.channel_dim)  # == x.
                # Update the stored covariance stats; this must not be tracked
                # by autograd, so do it under no_grad().
                with torch.no_grad():
                    x = x.transpose(self.channel_dim, -1)
                    x = x.reshape(-1, x.shape[-1])
                    cov = torch.matmul(x.t(), x) / x.shape[0]
                    self.cov.mul_(self.beta).add_(cov, alpha=(1 - self.beta))
                self.step += 1
            return ans  # ans == x.
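# Usage sketch, not part of this commit: Decorrelate is an identity in the
# forward pass, so it can be dropped in after any layer without changing the
# model's outputs; only the gradients flowing back through it get an extra
# decorrelating term.  The toy layer, optimizer and dummy objective below are
# made up for illustration.
def _demo_decorrelate_usage():
    layer = torch.nn.Linear(16, 32)
    decorrelate = Decorrelate(32)
    optim = torch.optim.SGD(layer.parameters(), lr=0.01)
    for _ in range(100):
        x = torch.randn(64, 16)
        y = decorrelate(layer(x))  # forward pass: y equals layer(x)
        loss = (y ** 2).mean()     # dummy training objective
        optim.zero_grad()
        loss.backward()            # backward pass adds the decorrelating term
        optim.step()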
class JoinDropout(torch.nn.Module):
    """
    This module implements something like:
@@ -735,7 +833,7 @@ class JoinDropout(torch.nn.Module):
      beta: A value 0 < beta < 1 that controls decay of covariance stats
      channel_dim: The dimension of the input corresponding to the channel, e.g.
           -1, 0, 1, 2.
    """
    """
    def __init__(self,
                 num_channels: int,
                 apply_prob: float = 0.75,
@@ -955,9 +1053,19 @@ def _test_join_dropout():
    x = torch.randn(30000, D)
    # give it a non-unit covariance.
    m = torch.randn(D, D)
    m = torch.randn(D, D) * (D ** -0.5)
    _, S, _ = m.svd()
    print("M eigs = ", S[::10])
    x = torch.matmul(x, m)

    if True:
        # check that class Decorrelate does not crash when running..
        decorrelate = Decorrelate(D)
        x.requires_grad = True
        y = decorrelate(x)
        y.sum().backward()

    for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
        m1 = torch.nn.Dropout(dropout_rate)
        m2 = JoinDropout(D, apply_prob=1.0, dropout_rate=dropout_rate)


@@ -19,7 +19,7 @@ import copy
import math
import warnings
from typing import List, Optional, Tuple
import logging
import torch
from encoder_interface import EncoderInterface
from scaling import (
@@ -30,7 +30,7 @@ from scaling import (
    ScaledConv2d,
    ScaledLinear,
    JoinDropout,
    Decorrelate,
)
from torch import Tensor, nn
@@ -1032,11 +1032,13 @@ class Conv2dSubsampling(nn.Module):
        # itself has learned scale, so the extra degree of freedom is not
        # needed.
        self.out_norm = BasicNorm(out_channels, learn_eps=False)
        self.decorrelate = Decorrelate(out_channels)
        # constrain median of output to be close to zero.
        self.out_balancer = ActivationBalancer(
            channel_dim=-1, min_positive=0.45, max_positive=0.55
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.
@@ -1055,6 +1057,7 @@ class Conv2dSubsampling(nn.Module):
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
        x = self.out_norm(x)
        x = self.decorrelate(x)
        x = self.out_balancer(x)
        return x
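# Sanity-check sketch, not part of this commit: because Decorrelate is an
# identity in the forward pass, inserting self.decorrelate between out_norm
# and out_balancer leaves the subsampler's outputs unchanged; only the
# backward pass differs.  The channel count and shapes below are made up.
def _check_decorrelate_is_identity():
    decorrelate = Decorrelate(512)
    x = torch.randn(2, 50, 512)
    assert torch.equal(decorrelate(x), x)  # training mode: values pass through
    decorrelate.eval()
    assert torch.equal(decorrelate(x), x)  # eval mode: early return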