Mirror of https://github.com/k2-fsa/icefall.git

Commit 0fd2cb141f (parent 2621cb7f54): Code cleanup and refactoring
@@ -713,15 +713,65 @@ class GaussProjDrop(torch.nn.Module):
         x = (x_next * self.rand_scale + x_bypass)
         return x


+def _compute_correlation_loss(cov: Tensor,
+                              eps: float) -> Tensor:
+    """
+    Computes the correlation `loss`, which would be zero if the channel dimensions
+    are un-correlated, and equals num_channels if they are maximally correlated
+    (i.e., they all point in the same direction)
+    Args:
+      cov: Uncentered covariance matrix of shape (num_channels, num_channels),
+           does not have to be normalized by count.
+    """
+    inv_sqrt_diag = (cov.diag() + eps) ** -0.5
+    norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
+    num_channels = cov.shape[0]
+    loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels
+    return loss
+
+def _update_cov_stats(cov: Tensor,
+                      x: Tensor,
+                      beta: float) -> Tensor:
+    """
+    Updates covariance stats as a decaying sum, returning the result.
+    cov: Old covariance stats, to be added to and returned, of shape
+         (num_channels, num_channels)
+    x: Tensor of features/activations, of shape (num_frames, num_channels)
+    beta: The decay constant for the stats, e.g. 0.8.
+    """
+    return cov * beta + torch.matmul(x.t(), x) * (1-beta)
+
+
 class DecorrelateFunction(torch.autograd.Function):
-    # does nothing in forward pass.  In backward pass it modifies
-    # the gradients in such a way as to encourage the dims of x
-    # to become uncorrelated, taken over all the stats.
+    """
+    Function object for a function that does nothing in the forward pass;
+    but, in the backward pass, adds derivatives that encourage the channel dimensions
+    to be un-correlated with each other.
+    This should not be used in a half-precision-enabled area, use
+    with torch.cuda.amp.autocast(enabled=False).
+
+    Args:
+      x: The input tensor, which is also the function output.  It can have
+         arbitrary shape, but its dimension `channel_dim` (e.g. -1) will be
+         interpreted as the channel dimension.
+      old_cov: Covariance statistics from previous frames, accumulated as a decaying
+         sum over frames (not average), decaying by beta each time.
+      scale: The scale on the derivatives arising from this, expressed as
+         a fraction of the norm of the derivative at the output (applied per
+         frame).  We will further scale down the derivative if the normalized
+         loss is less than 1.
+      eps: Epsilon value to prevent division by zero when estimating diagonal
+         covariance.
+      beta: Decay constant that determines how we combine stats from x with
+         stats from cov; e.g. 0.8.
+      channel_dim: The dimension/axis of x corresponding to the channel, e.g. 0, 1, 2, -1.
+    """
     @staticmethod
     def forward(ctx, x: Tensor, old_cov: Tensor,
                 scale: float, eps: float, beta: float,
                 channel_dim: int) -> Tensor:
-        ctx.save_for_backward(x, old_cov)
+        ctx.save_for_backward(x.detach(), old_cov.detach())
         ctx.scale = scale
         ctx.eps = eps
         ctx.beta = beta
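As a quick illustration of what the new helper computes (a standalone sketch, not part of this commit: it inlines a copy of the helper, and the toy sizes and data are made up), the loss is near zero for independent channels and roughly num_channels - 1 for identical ones:

import torch

def compute_correlation_loss(cov, eps):
    # same computation as the _compute_correlation_loss helper added above
    inv_sqrt_diag = (cov.diag() + eps) ** -0.5
    norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
    num_channels = cov.shape[0]
    return ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels

C = 64
x_uncorr = torch.randn(10000, C)                              # independent channels
x_corr = torch.randn(10000, 1).expand(10000, C).contiguous()  # identical channels

for name, x in [("uncorrelated", x_uncorr), ("fully correlated", x_corr)]:
    cov = torch.matmul(x.t(), x)   # uncentered covariance, as the helper expects
    print(name, compute_correlation_loss(cov, eps=1.0e-04).item())
    # prints a value near zero for the first case and about C - 1 for the second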
@@ -731,41 +781,44 @@ class DecorrelateFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
         x, old_cov = ctx.saved_tensors
-        with torch.enable_grad():

+        # Reshape x and x_grad to be (num_frames, num_channels)
         x = x.transpose(-1, ctx.channel_dim)
         x_grad = x_grad.transpose(-1, ctx.channel_dim)
         num_channels = x.shape[-1]
         full_shape = x.shape
         x = x.reshape(-1, num_channels)
-        x = x.detach()
-        old_cov = old_cov.detach()
-        x.requires_grad = True
         x_grad = x_grad.reshape(-1, num_channels)
+        x.requires_grad = True

-        cov = old_cov * ctx.beta + torch.matmul(x.t(), x) * (1-ctx.beta)
-        inv_sqrt_diag = (cov.diag() + ctx.eps) ** -0.5
-        norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
-        loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels
+        with torch.enable_grad():
+            cov = _update_cov_stats(old_cov, x, ctx.beta)
+            loss = _compute_correlation_loss(cov, ctx.eps)
             if random.random() < 0.01:
                 logging.info(f"Decorrelate: loss = {loss}")
             loss.backward()
-        x_grad_new = x.grad
+        decorr_x_grad = x.grad
         assert x.grad is not None

         # Now, normalize the magnitudes of the rows of the new grad
         # contribution, to have magnitudes equals to ctx.scale times
         # `loss ** 0.5` times the magnitude of the original grad.
-        x_grad_new_scale = (x_grad_new ** 2).sum(dim=1)
-        x_grad_old_scale = (x_grad ** 2).sum(dim=1)
+        decorr_x_grad_sqnorm = (decorr_x_grad ** 2).sum(dim=1)
+        x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)

+        # loss.detach().clamp(min=0.0, max=1.0) is a factor that means once
+        # the loss starts getting quite small (less than 1), we start using
+        # smaller derivatives.
         decorr_loss_scale = ctx.scale * loss.detach().clamp(min=0.0, max=1.0)
+        scale = decorr_loss_scale * (x_grad_old_sqnorm / (decorr_x_grad_sqnorm + 1.0e-10)) ** 0.5
+        decorr_x_grad = decorr_x_grad * scale.unsqueeze(-1)

-        scale = decorr_loss_scale * (x_grad_old_scale / (x_grad_new_scale + 1.0e-10)) ** 0.5
-        x_grad_new = x_grad_new * scale.unsqueeze(-1)
-
-        x_grad = x_grad + x_grad_new
-        # reshape..
+        x_grad = x_grad + decorr_x_grad
+
+        # reshape back to original shape
         x_grad = x_grad.reshape(full_shape)
         x_grad = x_grad.transpose(-1, ctx.channel_dim)
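To see what the renamed variables and the new comment describe, here is a standalone sketch of the per-frame rescaling in isolation (not part of the commit; the tensor sizes and the loss value are arbitrary, and the names simply mirror the locals in backward()):

import torch

num_frames, num_channels = 5, 8
x_grad = torch.randn(num_frames, num_channels)         # original output gradient
decorr_x_grad = torch.randn(num_frames, num_channels)  # gradient of the correlation loss w.r.t. x
loss = torch.tensor(2.5)                               # pretend correlation loss (> 1, so it clamps to 1)
ctx_scale = 0.1                                        # the `scale` argument of DecorrelateFunction

decorr_x_grad_sqnorm = (decorr_x_grad ** 2).sum(dim=1)
x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)
decorr_loss_scale = ctx_scale * loss.clamp(min=0.0, max=1.0)
scale = decorr_loss_scale * (x_grad_old_sqnorm / (decorr_x_grad_sqnorm + 1.0e-10)) ** 0.5
decorr_x_grad = decorr_x_grad * scale.unsqueeze(-1)

# Per frame, the added decorrelation term now has norm
# ctx_scale * min(loss, 1) times the norm of the original gradient:
print(decorr_x_grad.norm(dim=1) / x_grad.norm(dim=1))  # all entries close to 0.1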
@@ -813,7 +866,6 @@ class Decorrelate(torch.nn.Module):
         self.step = 0


-
     def load_state_dict(self, *args, **kwargs):
         super(Decorrelate, self).load_state_dict(*args, **kwargs)
         self.step = int(self.step_buf.item())
@@ -842,149 +894,6 @@ class Decorrelate(torch.nn.Module):
         return ans  # ans == x.


-class JoinDropout(torch.nn.Module):
-    """
-    This module implements something like:
-          y = bypass + dropout(x)
-    but does it in such a way as to encourage x to vary in directions that will tend
-    to make the dimensions of y as decorrelated as possible.  We do this
-    by putting lots of dropout in directions in the space in which we
-    don't want x to vary (because it will tend to increase correlations between
-    dimensions in the output y).
-
-    Args:
-      num_channels:  The number of channels, e.g. 256.
-      apply_prob_decay: The probability with which we apply this each time, in
-           training mode, will decay as apply_prob_decay/(apply_prob_decay + step).
-      dropout_rate: This number determines the average dropout probability
-           (it will actually vary across dimensions).
-      eps: An epsilon used to prevent division by zero.
-      beta: A value 0 < beta < 1 that controls decay of covariance stats
-      channel_dim: The dimension of the input corresponding to the channel, e.g.
-           -1, 0, 1, 2.
-    """
-    def __init__(self,
-                 num_channels: int,
-                 apply_prob: float = 0.75,
-                 dropout_rate: float = 0.1,
-                 eps: float = 1.0e-04,
-                 beta: float = 0.95,
-                 channel_dim: int = -1):
-        super(JoinDropout, self).__init__()
-        self.apply_prob = apply_prob
-        self.dropout_rate = dropout_rate
-        self.channel_dim = channel_dim
-        self.eps = eps
-        self.beta = beta
-
-        self.register_buffer('T1', torch.eye(num_channels))
-        self.register_buffer('dropout_probs', torch.zeros(num_channels))
-        self.register_buffer('scales', torch.ones(num_channels))
-        self.register_buffer('T2', torch.eye(num_channels))
-        self.register_buffer('cov', torch.zeros(num_channels, num_channels))
-        self.step = 0
-
-    def _update_covar_stats(self, y: Tensor) -> None:
-        """
-        Args:
-          y: Tensor of shape (*, num_channels), of output.
-        Updates covariance stats self.cov
-        """
-        y = y.detach()
-        y = y.reshape(-1, y.shape[-1])
-        y = y * (y.shape[0] ** -0.5)  # avoid overflow in half precision
-        cov = torch.matmul(y.t(), y)
-        self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
-
-    def _update_transforms(self):
-        norm_cov, inv_sqrt_diag = self._normalize_covar(self.cov)
-
-        U, S, _ = norm_cov.svd()  # because diag of norm_cov is 1.0, S.mean() == 1.0
-
-        if random.random() < 0.1:
-            logging.info(f"JoinDropout: max,min eig of normalized cov is: {S.max().item():.2e},{S.min().item():.2e}")
-
-        dropout_probs = (S.sqrt() - 0.99).clamp(min=0)
-        dropout_probs = dropout_probs * (self.dropout_rate / dropout_probs.mean())
-        dropout_probs = dropout_probs.clamp(max=0.5)
-        self.dropout_probs[:] = dropout_probs
-        self.scales[:] = 1.0 / (1 - dropout_probs)
-
-        # row indexes of U correspond to channels, column indexes correspond to
-        # singular values: cov = U * diag(S) * U.t() where * is matmul.
-
-        # Transform T1, which we'll incorporate as torch.matmul(x, self.T1), is:
-        # (i) multiply by inv_sqrt_diag which makes the covariance have
-        #     a unit diagonal.
-        # (ii) multiply by U, which diagonalizes norm_cov (uncorrelated channels)
-        self.T1[:] = (inv_sqrt_diag.unsqueeze(-1) * U)
-
-        # Transform T2, which we'll incorporate as torch.matmul(x, self.T2), is:
-        # (i) multiply by U, which un-diagonalizes norm_cov
-        # (ii) divide by inv_sqrt_diag which makes the covariance have its original
-        #      diagonal values.
-        self.T2[:] = (U.t() / inv_sqrt_diag)
-
-        if random.random() < 0.01:
-            d = torch.matmul(self.T1, self.T2) - torch.eye(self.T1.shape[0],
-                                                           device=self.T1.device,
-                                                           dtype=self.T1.dtype)
-            assert torch.all(d.abs() < 0.01)
-
-    def _normalize_covar(self, cov: Tensor) -> Tensor:
-        """
-        Normlizes a covariance matrix so that its diagonal is 1, by multiplying by
-        its diagonal**-0.5 on both sides.
-        Args:
-           cov: matrix to normalize
-        Returns normalized_cov, inv_sqrt_diag
-        """
-        diag = cov.diag()
-        inv_sqrt_diag = (diag + self.eps) ** -0.5
-        cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
-        return cov, inv_sqrt_diag
-
-    def forward(self, bypass: Tensor, x: Tensor) -> Tensor:
-        apply_prob = self.apply_prob
-        if not self.training or random.random() > apply_prob:
-            return bypass + x
-        else:
-            x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
-            bypass = bypass.transpose(self.channel_dim, -1)
-
-            x = torch.matmul(x, self.T1.clone())
-
-            mask = (torch.rand_like(x) > self.dropout_probs)
-            x = (x * mask) * self.scales.clone()
-            x = torch.matmul(x, self.T2.clone())
-
-            y = bypass + x
-            self.step += 1
-            with torch.cuda.amp.autocast(enabled=False):
-                if self.step % 4 == 0 or __name__ == "__main__":
-                    self._update_covar_stats(y)
-                if self.step % 40 == 0 or __name__ == "__main__":
-                    # note: important that 40 is a multiple of 4
-                    self._update_transforms()
-
-            y = y.transpose(self.channel_dim, -1)
-            return y
-
-
 def _test_activation_balancer_sign():
     probs = torch.arange(0, 1, 0.01)
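For reference on the class being deleted: the removed _update_transforms builds T1 by scaling the rows of U by inv_sqrt_diag and T2 by dividing the columns of U.t() by the same factors, so T1 @ T2 = diag(d) U U^T diag(1/d) = I, which is exactly what its occasional assert checks. A standalone sketch of that identity (not part of the commit; the toy covariance and sizes are made up):

import torch

eps = 1.0e-04
C = 16
y = torch.randn(1000, C)
cov = torch.matmul(y.t(), y)                 # toy uncentered covariance

inv_sqrt_diag = (cov.diag() + eps) ** -0.5
norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
U, S, _ = norm_cov.svd()                     # U is orthogonal since norm_cov is symmetric

T1 = inv_sqrt_diag.unsqueeze(-1) * U         # whiten the diagonal, then rotate into the eigenbasis
T2 = U.t() / inv_sqrt_diag                   # rotate back, then undo the diagonal scaling

assert torch.allclose(torch.matmul(T1, T2), torch.eye(C), atol=1.0e-03)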
@@ -1101,41 +1010,12 @@ def _test_decorrelate():


-def _test_join_dropout():
-    D = 384
-    x = torch.randn(30000, D)
-
-    # give it a non-unit covariance.
-    m = torch.randn(D, D) * (D ** -0.5)
-    _, S, _ = m.svd()
-    print("M eigs = ", S[::10])
-    x = torch.matmul(x, m)
-
-    for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
-        m1 = torch.nn.Dropout(dropout_rate)
-        m2 = JoinDropout(D, apply_prob=1.0, dropout_rate=dropout_rate)
-        bypass = torch.zeros_like(x)
-        for mode in ['train', 'eval']:
-            y1 = m1(x)
-            for _ in range(2):
-                y2 = m2(bypass, x)
-            xmag = (x*x).mean()
-            y1mag = (y1*y1).mean()
-            cross1 = (x*y1).mean()
-            y2mag = (y2*y2).mean()
-            cross2 = (x*y2).mean()
-            print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}, ratio1={y1mag/cross1}, ratio2={y2mag/cross2}")
-            m1.eval()
-            m2.eval()
-
-
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
     _test_decorrelate()
-    _test_join_dropout()
     _test_gauss_proj_drop()
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()
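Since _test_join_dropout is removed along with the class it exercised, a minimal smoke test for the refactored DecorrelateFunction could look like the following. This is only a sketch, not part of the commit: it assumes the modified scaling.py is importable, and it relies on forward returning its input unchanged, as the new docstring states.

import torch
from scaling import DecorrelateFunction  # assumes scaling.py is on the path

num_frames, num_channels = 100, 32
x = torch.randn(num_frames, num_channels, requires_grad=True)
old_cov = torch.zeros(num_channels, num_channels)

# argument order follows forward(): x, old_cov, scale, eps, beta, channel_dim
y = DecorrelateFunction.apply(x, old_cov, 0.1, 1.0e-04, 0.8, -1)
assert torch.equal(y, x)            # the forward pass is the identity

y.sum().backward()                  # the backward pass adds the decorrelation term
assert x.grad is not None and x.grad.shape == x.shape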

@@ -29,7 +29,6 @@ from scaling import (
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,
-    JoinDropout,
     Decorrelate,
 )
 from torch import Tensor, nn