Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-18 21:44:18 +00:00)

Remove Decorrelate() class

commit ca7cffcb42 (parent 7338c60296)
@@ -410,188 +410,6 @@ class GaussProjDrop(torch.nn.Module):
        return x


def _compute_correlation_loss(cov: Tensor,
                              eps: float) -> Tensor:
    """
    Computes the correlation `loss`, which would be zero if the channel dimensions
    are un-correlated, and is close to num_channels if they are maximally correlated
    (i.e. they all point in the same direction).
    Args:
       cov: Uncentered covariance matrix of shape (num_channels, num_channels);
          it does not have to be normalized by count.
    """
    inv_sqrt_diag = (cov.diag() + eps) ** -0.5
    norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
    num_channels = cov.shape[0]
    loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels
    return loss
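
As a quick sanity check on what this loss measures, an illustrative sketch (not part of the diff; it assumes _compute_correlation_loss is in scope):

import torch

C = 4
# Uncorrelated channels: a near-diagonal uncentered covariance gives a loss near 0.
x = torch.randn(10000, C)
print(_compute_correlation_loss(torch.matmul(x.t(), x), eps=1.0e-05))   # ~0

# Maximally correlated channels: a rank-1 covariance gives a loss of C - 1 (~C).
v = torch.randn(10000, 1).expand(10000, C)
print(_compute_correlation_loss(torch.matmul(v.t(), v), eps=1.0e-05))   # ~3.0 for C=4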


def _update_cov_stats(cov: Tensor,
                      x: Tensor,
                      beta: float) -> Tensor:
    """
    Updates covariance stats as a decaying sum, returning the result.
    Args:
       cov: Old covariance stats, to be added to and returned, of shape
          (num_channels, num_channels)
       x: Tensor of features/activations, of shape (num_frames, num_channels)
       beta: The decay constant for the stats, e.g. 0.8.
    """
    new_cov = torch.matmul(x.t(), x)
    return cov * beta + new_cov * (1 - beta)
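
For intuition, an illustrative sketch of how these stats evolve over successive calls (the loop and names here are hypothetical, not from the diff):

import torch

beta = 0.8
cov = torch.zeros(3, 3)
for step in range(5):
    frames = torch.randn(100, 3)          # one batch of (num_frames, num_channels)
    cov = _update_cov_stats(cov, frames, beta)
# cov is now (1 - beta) * sum_k beta**k * (X_k^T X_k) over the batches, newest first:
# an exponentially weighted, uncentered covariance dominated by recent batches.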


class DecorrelateFunction(torch.autograd.Function):
    """
    Function object for a function that does nothing in the forward pass,
    but, in the backward pass, adds derivatives that encourage the channel dimensions
    to be un-correlated with each other.
    This should not be used where half precision (amp) is enabled; wrap the call in
    `with torch.cuda.amp.autocast(enabled=False):`.

    Args:
        x: The input tensor, which is also the function output.  It can have
           arbitrary shape, but its dimension `channel_dim` (e.g. -1) will be
           interpreted as the channel dimension.
        old_cov: Covariance statistics from previous frames, accumulated as a decaying
           sum over frames (not an average), decaying by beta each time.
        scale: The scale on the derivatives arising from this, expressed as
           a fraction of the norm of the derivative at the output (applied per
           frame).  We will further scale down the derivative if the normalized
           loss is less than 1.
        eps: Epsilon value to prevent division by zero when estimating diagonal
           covariance.
        beta: Decay constant that determines how we combine stats from x with
           stats from cov; e.g. 0.8.
        channel_dim: The dimension/axis of x corresponding to the channel, e.g. 0, 1, 2, -1.
    """
    @staticmethod
    def forward(ctx, x: Tensor, old_cov: Tensor,
                scale: float, eps: float, beta: float,
                channel_dim: int) -> Tensor:
        ctx.save_for_backward(x.detach(), old_cov.detach())
        ctx.scale = scale
        ctx.eps = eps
        ctx.beta = beta
        ctx.channel_dim = channel_dim
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
        x, old_cov = ctx.saved_tensors

        # Reshape x and x_grad to be (num_frames, num_channels)
        x = x.transpose(-1, ctx.channel_dim)
        x_grad = x_grad.transpose(-1, ctx.channel_dim)
        num_channels = x.shape[-1]
        full_shape = x.shape
        x = x.reshape(-1, num_channels)
        x_grad = x_grad.reshape(-1, num_channels)
        x.requires_grad = True

        # Normalize the contributions of frames/pixels of x to the covariance so
        # their magnitudes are proportional to the norm of the gradient on that
        # frame; the goal is to exclude "don't-care" frames such as padding frames
        # from the computation.
        x_grad_sqrt_norm = (x_grad ** 2).sum(dim=1) ** 0.25
        x_grad_sqrt_norm /= x_grad_sqrt_norm.mean()
        x_grad_sqrt_norm_is_inf = (x_grad_sqrt_norm - x_grad_sqrt_norm != 0)
        x_grad_sqrt_norm.masked_fill_(x_grad_sqrt_norm_is_inf, 1.0)

        with torch.enable_grad():
            # scale up frames with larger grads.
            x_scaled = x * x_grad_sqrt_norm.unsqueeze(-1)

            cov = _update_cov_stats(old_cov, x_scaled, ctx.beta)
            assert old_cov.dtype != torch.float16
            old_cov[:] = cov  # update the stats in the caller's buffer.  This is not
                              # really how backprop is supposed to work, but this
                              # input is not differentiable.
            loss = _compute_correlation_loss(cov, ctx.eps)
            assert loss.dtype == torch.float32

            if random.random() < 0.05:
                logging.info(f"Decorrelate: loss = {loss}")

            loss.backward()

        decorr_x_grad = x.grad

        scale = ctx.scale * ((x_grad ** 2).mean() / ((decorr_x_grad ** 2).mean() + 1.0e-20)) ** 0.5
        decorr_x_grad = decorr_x_grad * scale

        x_grad = x_grad + decorr_x_grad

        # reshape back to the original shape
        x_grad = x_grad.reshape(full_shape)
        x_grad = x_grad.transpose(-1, ctx.channel_dim)

        return x_grad, None, None, None, None, None
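
A minimal sketch of how this Function is driven (illustrative only; the Decorrelate module below is the real caller, and the constants here are just its defaults). It assumes float32 inputs, as required:

import torch

num_channels = 8
cov = torch.zeros(num_channels, num_channels)           # persistent stats buffer
x = torch.randn(50, num_channels, requires_grad=True)   # float32, (frames, channels)

# Forward is the identity; in backward, a decorrelation term is added to x.grad
# and `cov` is updated in place with the decayed covariance stats.
y = DecorrelateFunction.apply(x, cov, 0.1, 1.0e-05, 0.95, -1)
#                             x, old_cov, scale, eps, beta, channel_dim
y.sum().backward()
print(x.grad.shape, cov.abs().sum())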


class Decorrelate(torch.nn.Module):
    """
    This module does nothing in the forward pass, but in the backward pass it modifies
    the derivatives in such a way as to encourage the dimensions of its input to become
    decorrelated.

    Args:
       num_channels:  The number of channels, e.g. 256.
       apply_prob_decay:  Controls the probability of applying this module on any
             given step in training mode; that probability decays as
             apply_prob_decay / (apply_prob_decay + step).
       scale: Determines the scale of the gradient contribution from this module,
              relative to whatever the gradient was before; this is applied per
              frame or pixel, by scaling gradients.
       eps: An epsilon used to prevent division by zero.
       beta: A value 0 < beta < 1 that controls the decay of the covariance stats.
       channel_dim: The dimension of the input corresponding to the channel, e.g.
               -1, 0, 1, 2.
    """
    def __init__(self,
                 num_channels: int,
                 scale: float = 0.1,
                 apply_prob_decay: int = 1000,
                 eps: float = 1.0e-05,
                 beta: float = 0.95,
                 channel_dim: int = -1):
        super(Decorrelate, self).__init__()
        self.scale = scale
        self.apply_prob_decay = apply_prob_decay
        self.eps = eps
        self.beta = beta
        self.channel_dim = channel_dim

        self.register_buffer('cov', torch.zeros(num_channels, num_channels))
        # step_buf is a copy of step, included so it will be loaded/saved with
        # the model.
        self.register_buffer('step_buf', torch.tensor(0.0))
        self.step = 0

    def load_state_dict(self, *args, **kwargs):
        super(Decorrelate, self).load_state_dict(*args, **kwargs)
        self.step = int(self.step_buf.item())

    def forward(self, x: Tensor) -> Tensor:
        if not self.training:
            return x
        else:
            apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay)
            self.step += 1
            self.step_buf.fill_(float(self.step))
            if random.random() > apply_prob:
                return x
            with torch.cuda.amp.autocast(enabled=False):
                x = x.to(torch.float32)
                # The function updates self.cov in its backward pass (it needs the
                # gradient norm, for frame weighting).
                ans = DecorrelateFunction.apply(x, self.cov,
                                                self.scale, self.eps, self.beta,
                                                self.channel_dim)  # == x
                return ans
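
A usage sketch for the removed module (illustrative; this toy model is hypothetical, not from icefall). Since the forward pass is the identity, Decorrelate can be dropped in after any layer whose outputs should be decorrelated:

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(80, 256),
    nn.ReLU(),
    Decorrelate(256, scale=0.1, channel_dim=-1),   # no-op in forward
    nn.Linear(256, 10),
)
model.train()
x = torch.randn(32, 100, 80)      # (batch, time, feature)
model(x).sum().backward()         # decorrelation gradients are injected here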


def _test_activation_balancer_sign():

@@ -688,34 +506,11 @@ def _test_gauss_proj_drop():
    m2.eval()


def _test_decorrelate():
    D = 384
    x = torch.randn(30000, D)
    # give it a non-unit covariance.
    m = torch.randn(D, D) * (D ** -0.5)
    _, S, _ = m.svd()
    print("M eigs = ", S[::10])
    x = torch.matmul(x, m)

    # check that class Decorrelate does not crash when running.
    decorrelate = Decorrelate(D)
    x.requires_grad = True
    y = decorrelate(x)
    y.sum().backward()

    decorrelate2 = Decorrelate(D)
    decorrelate2.load_state_dict(decorrelate.state_dict())
    assert decorrelate2.step == decorrelate.step
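
An extra check one could add alongside this test (illustrative sketch with a hypothetical helper name, not part of the diff): since the upstream gradient of y.sum() is all ones, whatever else appears in x.grad is the injected decorrelation term, and its RMS should come out near the configured scale (0.1 by default).

def _check_decorrelate_grad_scale():
    D = 384
    x = torch.randn(1000, D)
    x.requires_grad = True
    y = Decorrelate(D)(x)          # training mode by default; forward is identity
    y.sum().backward()
    decorr_part = x.grad - torch.ones_like(x)
    print("RMS of decorrelation gradient:", (decorr_part ** 2).mean().sqrt().item())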


if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    _test_decorrelate()
    _test_gauss_proj_drop()
    _test_activation_balancer_sign()
    _test_activation_balancer_magnitude()