Mirror of https://github.com/k2-fsa/icefall.git
Do scaling a different way, so loss function is more consistent; accum stats in backward pass

commit ff0309947a
parent 58cbc3d961
@@ -803,7 +803,6 @@ class DecorrelateFunction(torch.autograd.Function):
     def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
         x, old_cov = ctx.saved_tensors
 
-
         # Reshape x and x_grad to be (num_frames, num_channels)
         x = x.transpose(-1, ctx.channel_dim)
         x_grad = x_grad.transpose(-1, ctx.channel_dim)
@@ -813,8 +812,22 @@ class DecorrelateFunction(torch.autograd.Function):
         x_grad = x_grad.reshape(-1, num_channels)
         x.requires_grad = True
 
+        # Now, normalize the contributions of frames/pixels x to the covariance,
+        # to have magnitudes proportional to the norm of the gradient on that
+        # frame; the goal is to exclude "don't-care" frames such as padding frames from
+        # the computation.
+
+        x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)
+        x_sqnorm = (x.detach() ** 2).sum(dim=1)
+        x_desired_sqscale = x_grad_old_sqnorm ** 0.5  # desired scale of x in sum for cov
+        x_factor = (x_desired_sqscale / (x_sqnorm + ctx.eps)) ** 0.5
+
         with torch.enable_grad():
-            cov = _update_cov_stats(old_cov, x, ctx.beta)
+            scaled_x = x * x_factor.unsqueeze(-1)
+            cov = _update_cov_stats(old_cov, scaled_x, ctx.beta)
+            old_cov[:] = cov  # update the stats outside!  This is not really
+                              # how backprop is supposed to work, but this input
+                              # is not differentiable..
             loss = _compute_correlation_loss(cov, ctx.eps)
 
         if random.random() < 0.01:
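The frame weighting added in this hunk can be illustrated in isolation. Below is a minimal standalone sketch of the idea, not the icefall code itself: the tensors stand in for the reshaped x and x_grad, and eps is a stand-in for ctx.eps. Frames whose incoming gradient is zero (e.g. padding frames) are scaled to (nearly) zero and so contribute essentially nothing to the covariance stats.

    # Sketch of the per-frame weighting of covariance contributions.
    # Assumes x, x_grad of shape (num_frames, num_channels); eps is illustrative.
    import torch

    num_frames, num_channels = 6, 4
    x = torch.randn(num_frames, num_channels)
    x_grad = torch.randn(num_frames, num_channels)
    x_grad[3:] = 0.0                              # e.g. padding frames get zero gradient

    x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)  # per-frame gradient energy
    x_sqnorm = (x ** 2).sum(dim=1)                # per-frame feature energy
    x_desired_sqscale = x_grad_old_sqnorm ** 0.5  # desired squared scale in the cov sum
    eps = 1.0e-05
    x_factor = (x_desired_sqscale / (x_sqnorm + eps)) ** 0.5

    scaled_x = x * x_factor.unsqueeze(-1)
    cov_contrib = torch.matmul(scaled_x.t(), scaled_x)
    print(x_factor)   # ~0 for the zero-gradient frames, so they drop out of cov_contrib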
@@ -823,18 +836,14 @@ class DecorrelateFunction(torch.autograd.Function):
 
         decorr_x_grad = x.grad
 
-        # Now, normalize the magnitudes of the rows of the new grad
-        # contribution, to have magnitudes equals to ctx.scale times
-        # `loss ** 0.5` times the magnitude of the original grad.
-        decorr_x_grad_sqnorm = (decorr_x_grad ** 2).sum(dim=1)
-        x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)
 
         # loss.detach().clamp(min=0.0, max=1.0) is a factor that means once
         # the loss starts getting quite small (less than 1), we start using
         # smaller derivatives.
+
         decorr_loss_scale = ctx.scale * loss.detach().clamp(min=0.0, max=1.0)
-        scale = decorr_loss_scale * (x_grad_old_sqnorm / (decorr_x_grad_sqnorm + 1.0e-20)) ** 0.5
-        decorr_x_grad = decorr_x_grad * scale.unsqueeze(-1)
+        scale = decorr_loss_scale * ((x_grad ** 2).mean() / ((decorr_x_grad ** 2).mean() + 1.0e-20)) ** 0.5
+        decorr_x_grad = decorr_x_grad * scale
 
         x_grad = x_grad + decorr_x_grad
 
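This hunk is the "do scaling a different way" part of the commit message. A small sketch of the contrast, using stand-in tensors and a stand-in constant for ctx.scale * loss.detach().clamp(min=0.0, max=1.0): the old code rescaled each frame of the decorrelation gradient by that frame's own gradient-norm ratio, while the new code applies one global factor, so the relative per-frame magnitudes of the decorrelation gradient are preserved.

    # Old per-frame rescaling vs. new global rescaling of the decorrelation grad.
    import torch

    x_grad = torch.randn(6, 4)           # original gradient w.r.t. x
    decorr_x_grad = torch.randn(6, 4)    # gradient of the decorrelation loss
    decorr_loss_scale = 0.1              # stand-in for ctx.scale * clamped loss

    # old: each frame forced to match its own original gradient norm
    old_scale = decorr_loss_scale * (
        (x_grad ** 2).sum(dim=1) / ((decorr_x_grad ** 2).sum(dim=1) + 1.0e-20)) ** 0.5
    old_contrib = decorr_x_grad * old_scale.unsqueeze(-1)

    # new: a single global factor; per-frame proportions are kept intact
    new_scale = decorr_loss_scale * (
        (x_grad ** 2).mean() / ((decorr_x_grad ** 2).mean() + 1.0e-20)) ** 0.5
    new_contrib = decorr_x_grad * new_scale

    final_grad = x_grad + new_contrib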
@@ -902,19 +911,12 @@ class Decorrelate(torch.nn.Module):
             return x
         with torch.cuda.amp.autocast(enabled=False):
             x = x.to(torch.float32)
-            ans = DecorrelateFunction.apply(x, self.cov.clone(),
+            # the function updates self.cov in its backward pass (it needs the gradient
+            # norm, for frame weighting).
+            ans = DecorrelateFunction.apply(x, self.cov,
                                             self.scale, self.eps, self.beta,
                                             self.channel_dim)  # == x.
-
-            x = x.transpose(self.channel_dim, -1)
-            x = x.reshape(-1, x.shape[-1])
-            cov = torch.matmul(x.t(), x)
-            with torch.no_grad():
-                self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
-            m = self.cov.max()
-            assert m == m
-
-            return ans  # ans == x.
+            return ans
 
 
 
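The "accum stats in backward pass" part works because a torch.autograd.Function may mutate a tensor argument inside backward() even though that argument receives no gradient; the module therefore passes self.cov itself (no .clone()) and drops the stats update from forward(). A minimal sketch of the pattern, with illustrative names rather than the icefall implementation:

    # Custom autograd.Function that updates a stats buffer in its backward pass.
    import torch
    from torch import Tensor
    from typing import Tuple

    class AccumStatsFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: Tensor, stats: Tensor, beta: float) -> Tensor:
            ctx.save_for_backward(x, stats)
            ctx.beta = beta
            return x  # identity in the forward direction

        @staticmethod
        def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None]:
            x, stats = ctx.saved_tensors
            with torch.no_grad():
                cov = torch.matmul(x.t(), x)
                # update the stats tensor in place, outside normal backprop
                stats.mul_(ctx.beta).add_(cov, alpha=(1 - ctx.beta))
            return x_grad, None, None

    stats = torch.zeros(4, 4)
    x = torch.randn(8, 4, requires_grad=True)
    y = AccumStatsFn.apply(x, stats, 0.9)
    y.sum().backward()   # stats is updated here, during the backward pass

Because the buffer is only touched in backward(), the update can also see the incoming gradient, which is what makes the frame weighting in the earlier hunk possible.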