Fix issue with cov scale
parent c671e213fc
commit eeb95ed502
@@ -819,14 +819,15 @@ class DecorrelateFunction(torch.autograd.Function):
         x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)
         x_sqnorm = (x.detach() ** 2).sum(dim=1)
-        x_desired_sqscale = x_grad_old_sqnorm ** 0.5 # desired scale of x in sum for cov
+        x_desired_sqscale = x_grad_old_sqnorm ** 0.5 # desired scale of x*x in sum for cov
         x_desired_sqscale /= (x_desired_sqscale.sum() + 1.0e-20) # sum-to-one scales
         x_desired_sqscale_is_inf = (x_desired_sqscale - x_desired_sqscale != 0)
         # if grads are inf, use equal scales for frames (can happen due to GradScaler, in half
         # precision)
         x_desired_sqscale.masked_fill_(x_desired_sqscale_is_inf, 1.0 / x_desired_sqscale.numel())

-        x_factor = (x_desired_sqscale / (x_sqnorm + ctx.eps)) ** 0.5
+        x_factor = (x_desired_sqscale * num_channels / (x_sqnorm + ctx.eps)) ** 0.5

         with torch.enable_grad():
             scaled_x = x * x_factor.unsqueeze(-1)
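Why the num_channels factor matters: x_desired_sqscale is normalized to sum to one over frames, so the old x_factor made sum(scaled_x ** 2) come out near 1 and the diagonal of the covariance accumulated from scaled_x averaged only ~1/num_channels. Multiplying by num_channels restores a unit-scale diagonal. The sketch below checks that invariant standalone, assuming cov is formed as scaled_x.t() @ scaled_x (a step not shown in this hunk); it also demonstrates the t - t != 0 idiom used above to flag inf/NaN scales.

import torch

# Standalone check of the scaling invariant behind this fix; names mirror
# the diff, but cov = scaled_x.t() @ scaled_x is an assumption.
num_frames, num_channels = 100, 16
x = torch.randn(num_frames, num_channels)
x_grad = torch.randn(num_frames, num_channels)
eps = 1.0e-05  # stand-in for ctx.eps

x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)
x_sqnorm = (x ** 2).sum(dim=1)
x_desired_sqscale = x_grad_old_sqnorm ** 0.5
x_desired_sqscale /= (x_desired_sqscale.sum() + 1.0e-20)  # sums to one

# Each frame i now contributes x_desired_sqscale[i] * num_channels to
# sum(scaled_x ** 2), so the total is ~num_channels and diag(cov)
# averages ~1; without the num_channels factor it averaged ~1/num_channels.
x_factor = (x_desired_sqscale * num_channels / (x_sqnorm + eps)) ** 0.5
scaled_x = x * x_factor.unsqueeze(-1)
cov = scaled_x.t() @ scaled_x
print((scaled_x ** 2).sum().item())   # ~num_channels
print(cov.diagonal().mean().item())   # ~1.0

# The "t - t != 0" idiom flags non-finite entries: inf - inf and nan - nan
# are both nan, and nan != 0 evaluates True.
t = torch.tensor([1.0, float("inf"), float("nan")])
print(t - t != 0)  # tensor([False, True, True])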
@@ -837,6 +838,8 @@ class DecorrelateFunction(torch.autograd.Function):
             # is not differentiable..
             loss = _compute_correlation_loss(cov, ctx.eps)

+            #print(f"x_sqnorm mean = {x_sqnorm.mean().item()}, x_sqnorm_mean={x_sqnorm.mean().item()}, x_desired_sqscale_sum={x_desired_sqscale.sum()}, x_grad_old_sqnorm mean = {x_grad_old_sqnorm.mean().item()}, x**2_mean = {(x**2).mean().item()}, scaled_x**2_mean = {(scaled_x**2).mean().item()}, (cov-abs-mean)={cov.abs().mean().item()}, old_cov_abs_mean={old_cov.abs().mean().item()}, loss = {loss}")
+
         if random.random() < 0.01:
             logging.info(f"Decorrelate: loss = {loss}")
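_compute_correlation_loss is called here but not defined in this diff. A hypothetical sketch of such a loss, assuming it normalizes cov by its diagonal and penalizes the off-diagonal (correlated) entries; the real definition elsewhere in scaling.py may differ:

import torch

# Hypothetical stand-in for _compute_correlation_loss (an assumption, not
# the real definition): normalize cov to a correlation-like matrix, then
# average the squared off-diagonal entries.
def correlation_loss_sketch(cov: torch.Tensor, eps: float) -> torch.Tensor:
    inv_sqrt_diag = (cov.diagonal() + eps) ** -0.5
    corr = cov * inv_sqrt_diag.unsqueeze(0) * inv_sqrt_diag.unsqueeze(1)
    num_channels = cov.shape[0]
    off_diag = ~torch.eye(num_channels, dtype=torch.bool, device=cov.device)
    # Zero for a perfectly decorrelated (diagonal) covariance.
    return (corr[off_diag] ** 2).mean()

m = torch.randn(100, 8)
cov = m.t() @ m  # a valid (PSD) covariance for testing
print(correlation_loss_sketch(cov, eps=1.0e-05).item())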
@@ -1025,7 +1028,7 @@ def _test_pseudo_normalize():
     x = torch.randn(3, 4)
     x.requires_grad = True
     y = PseudoNormalizeFunction.apply(x)
-    l = y.sin().sum()
+    l = (y**2).sum()
     l.backward()
     assert (x.grad * x).sum().abs() < 0.1
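The updated test drives PseudoNormalizeFunction with l = (y**2).sum(), whose gradient with respect to y is 2*y = 2*x, then asserts that x.grad is nearly orthogonal to x. A minimal sketch of an implementation consistent with that assertion, assuming the forward is the identity and the backward projects out the per-row component of the gradient along x (the real class is defined elsewhere in the file):

import torch

# Sketch of a PseudoNormalizeFunction consistent with the test above
# (an assumption; not the actual icefall implementation).
class PseudoNormalizeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clone()  # identity; clone so autograd sees a fresh output

    @staticmethod
    def backward(ctx, y_grad):
        (x,) = ctx.saved_tensors
        eps = 1.0e-20
        # Remove the radial component per row, so (x_grad * x).sum() ~= 0,
        # as the gradient of a length-normalization would be.
        coeff = (y_grad * x).sum(dim=1, keepdim=True) / ((x * x).sum(dim=1, keepdim=True) + eps)
        return y_grad - coeff * x

x = torch.randn(3, 4, requires_grad=True)
y = PseudoNormalizeFunction.apply(x)
(y ** 2).sum().backward()
# dl/dy = 2*y = 2*x, which the projection maps to ~0:
print((x.grad * x).sum().abs().item())  # well under the test's 0.1 bound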