Add Decorrelate module that adds something to gradients in backward pass
This commit is contained in:
parent 9fb8645168
commit 46ca1cd4c4
@@ -713,6 +713,104 @@ class GaussProjDrop(torch.nn.Module):
        x = (x_next * self.rand_scale + x_bypass)
        return x

class DecorrelateFunction(torch.autograd.Function):
    # Does nothing in the forward pass.  In the backward pass it modifies
    # the gradients in such a way as to encourage the dims of x to become
    # uncorrelated, with the stats taken over all the other dims.
    @staticmethod
    def forward(ctx, x: Tensor, old_cov: Tensor,
                scale: float, eps: float, beta: float,
                channel_dim: int) -> Tensor:
        ctx.save_for_backward(x, old_cov)
        ctx.scale = scale
        ctx.eps = eps
        ctx.beta = beta
        ctx.channel_dim = channel_dim
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
        x, old_cov = ctx.saved_tensors
        with torch.enable_grad():
            x = x.transpose(-1, ctx.channel_dim)
            x_grad = x_grad.transpose(-1, ctx.channel_dim)
            num_channels = x.shape[-1]
            full_shape = x.shape
            x = x.reshape(-1, num_channels)
            x = x.detach()
            x.requires_grad = True
            x_grad = x_grad.reshape(-1, num_channels)

            cov = old_cov * ctx.beta + torch.matmul(x.t(), x) * (1 - ctx.beta)
            inv_sqrt_diag = (cov.diag() + ctx.eps) ** -0.5
            norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
            print("Decorrelate: norm_cov = ", norm_cov)

            loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels - 1
            if random.random() < 0.01:
                logging.info(f"Decorrelate: loss = {loss}")
            loss.backward()
            x_grad_new = x.grad
            assert x_grad_new is not None

            # Now, normalize the magnitudes of the rows of the new grad
            # contribution, to have magnitudes equal to ctx.scale times
            # `loss ** 0.5` times the magnitude of the original grad.
            x_grad_new_scale = (x_grad_new ** 2).sum(dim=1)
            x_grad_old_scale = (x_grad ** 2).sum(dim=1)
            decorr_loss_scale = ctx.scale * loss.detach() ** 0.5
            scale = decorr_loss_scale * (x_grad_old_scale / (x_grad_new_scale + 1.0e-10)) ** 0.5
            x_grad_new = x_grad_new * scale.unsqueeze(-1)

            x_grad = x_grad + x_grad_new
            # reshape back to the original layout.
            x_grad = x_grad.reshape(full_shape)
            x_grad = x_grad.transpose(-1, ctx.channel_dim)

        return x_grad, None, None, None, None, None

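To see what the backward pass is differentiating, here is a small standalone sketch (illustration only, not part of the commit; the helper name and test sizes are made up, and it omits the EMA with old_cov, the scale/beta handling, and the gradient rescaling done above): the covariance is normalized so its diagonal is roughly one, and the squared off-diagonal entries are summed, which is near zero for uncorrelated channels and large for strongly correlated ones.

import torch

def decorrelation_penalty(x: torch.Tensor, eps: float = 1.0e-05) -> torch.Tensor:
    # x: (num_frames, num_channels).  Mirrors the normalization in
    # DecorrelateFunction.backward(): scale the (uncentered) covariance so its
    # diagonal is ~1, then penalize the off-diagonal entries.
    cov = torch.matmul(x.t(), x)
    inv_sqrt_diag = (cov.diag() + eps) ** -0.5
    norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
    off_diag = norm_cov - norm_cov.diag().diag()
    return (off_diag ** 2).sum()

x_uncorr = torch.randn(10000, 8)             # channels nearly uncorrelated
x_corr = x_uncorr @ torch.ones(8, 8)         # every channel is the same mixture
print(decorrelation_penalty(x_uncorr))       # small, close to 0
print(decorrelation_penalty(x_corr))         # close to 8 * 7 = 56
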
class Decorrelate(torch.nn.Module):
    """
    This module does nothing in the forward pass, but in the backward pass, modifies
    the derivatives in such a way as to encourage the dimensions of its input to become
    decorrelated.
    """
    def __init__(self,
                 num_channels: int,
                 scale: float = 0.1,
                 eps: float = 1.0e-05,
                 beta: float = 0.95,
                 channel_dim: int = -1):
        super(Decorrelate, self).__init__()
        self.scale = scale
        self.eps = eps
        self.beta = beta
        self.channel_dim = channel_dim

        self.register_buffer('cov', torch.zeros(num_channels, num_channels))
        self.step = 0

    def forward(self, x: Tensor) -> Tensor:
        if not self.training:
            return x
        else:
            with torch.cuda.amp.autocast(enabled=False):
                ans = DecorrelateFunction.apply(x, self.cov.clone(),
                                                self.scale, self.eps, self.beta,
                                                self.channel_dim)  # == x.

                x = x.transpose(self.channel_dim, -1)
                x = x.reshape(-1, x.shape[-1])
                cov = torch.matmul(x.t(), x) / x.shape[0]
                self.cov.mul_(self.beta).add_(cov, alpha=(1 - self.beta))
                self.step += 1

            return ans  # ans == x.

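As a usage illustration (not part of the diff; it assumes the scaling.py from this commit is importable), Decorrelate can simply be dropped after a layer: the forward output is exactly the layer's output, only the gradients flowing back (and the module's cov buffer) are affected, and in eval mode it is a pure pass-through.

import torch
from scaling import Decorrelate   # assumes the scaling.py from this commit

layer = torch.nn.Linear(16, 32)
decorrelate = Decorrelate(num_channels=32)   # channel_dim defaults to -1

x = torch.randn(10, 100, 16, requires_grad=True)
y = decorrelate(layer(x))    # forward pass: y equals layer(x)
y.sum().backward()           # backward pass: gradients pick up the decorrelation term

decorrelate.eval()           # at inference time the module does nothing
assert torch.equal(decorrelate(layer(x)), layer(x))
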
class JoinDropout(torch.nn.Module):
    """
    This module implements something like:
@@ -735,7 +833,7 @@ class JoinDropout(torch.nn.Module):
      beta: A value 0 < beta < 1 that controls decay of covariance stats
      channel_dim: The dimension of the input corresponding to the channel, e.g.
           -1, 0, 1, 2.
    """
    def __init__(self,
                 num_channels: int,
                 apply_prob: float = 0.75,
@@ -955,9 +1053,19 @@ def _test_join_dropout():
    x = torch.randn(30000, D)

    # give it a non-unit covariance.
    m = torch.randn(D, D) * (D ** -0.5)
    _, S, _ = m.svd()
    print("M eigs = ", S[::10])
    x = torch.matmul(x, m)

    if True:
        # check that class Decorrelate does not crash when running.
        decorrelate = Decorrelate(D)
        x.requires_grad = True
        y = decorrelate(x)
        y.sum().backward()

    for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
        m1 = torch.nn.Dropout(dropout_rate)
        m2 = JoinDropout(D, apply_prob=1.0, dropout_rate=dropout_rate)
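For context on the "non-unit covariance" step above (illustration only, not in the commit): white Gaussian data already has near-identity covariance, so the test mixes the channels with a random matrix; the mixed data has covariance roughly m.t() @ m, whose sizable off-diagonal entries give Decorrelate something to push against.

import torch

D = 64
x = torch.randn(30000, D)
m = torch.randn(D, D) * (D ** -0.5)

cov_before = torch.matmul(x.t(), x) / x.shape[0]             # ~ identity
x_mixed = torch.matmul(x, m)
cov_after = torch.matmul(x_mixed.t(), x_mixed) / x.shape[0]  # ~ m.t() @ m

print("off-diag norm before:", (cov_before - cov_before.diag().diag()).norm())
print("off-diag norm after: ", (cov_after - cov_after.diag().diag()).norm())
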
@@ -19,7 +19,7 @@ import copy
import math
import warnings
from typing import List, Optional, Tuple

import logging
import torch
from encoder_interface import EncoderInterface
from scaling import (
@@ -30,7 +30,7 @@ from scaling import (
    ScaledConv2d,
    ScaledLinear,
    JoinDropout,
    Decorrelate,
)
from torch import Tensor, nn

@@ -1032,11 +1032,13 @@ class Conv2dSubsampling(nn.Module):
        # itself has learned scale, so the extra degree of freedom is not
        # needed.
        self.out_norm = BasicNorm(out_channels, learn_eps=False)
        self.decorrelate = Decorrelate(out_channels)
        # constrain median of output to be close to zero.
        self.out_balancer = ActivationBalancer(
            channel_dim=-1, min_positive=0.45, max_positive=0.55
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

@@ -1055,6 +1057,7 @@
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
        x = self.out_norm(x)
        x = self.decorrelate(x)
        x = self.out_balancer(x)
        return x
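A quick check of the shape comment in the forward pass (illustration only; it assumes the two kernel-3, stride-2, unpadded convolutions that the ((T-1)//2 - 1)//2 formula implies): the time dimension shrinks by roughly a factor of four.

def subsampled_len(T: int) -> int:
    # ((T - 1) // 2 - 1) // 2: one stride-2 conv maps T -> (T - 1) // 2,
    # and the second conv is applied on top of that.
    return ((T - 1) // 2 - 1) // 2

for T in (16, 100, 1003):
    print(T, "->", subsampled_len(T))   # 16 -> 3, 100 -> 24, 1003 -> 250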