Do scaling a different way, so the loss function is more consistent; accumulate stats in the backward pass

Daniel Povey 2022-06-10 14:16:44 +08:00
parent 58cbc3d961
commit ff0309947a

@@ -803,7 +803,6 @@ class DecorrelateFunction(torch.autograd.Function):
     def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
         x, old_cov = ctx.saved_tensors
         # Reshape x and x_grad to be (num_frames, num_channels)
         x = x.transpose(-1, ctx.channel_dim)
         x_grad = x_grad.transpose(-1, ctx.channel_dim)
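
[Editor's note: the transpose-then-reshape above is the usual pattern for flattening a tensor to one row per frame before computing per-channel statistics. A minimal sketch of that pattern with made-up shapes and channel_dim, not code from this commit:

    import torch

    x = torch.randn(4, 256, 100)        # hypothetical (batch, channels, time)
    channel_dim = 1

    x = x.transpose(-1, channel_dim)    # -> (4, 100, 256), channels last
    num_channels = x.shape[-1]
    x = x.reshape(-1, num_channels)     # -> (400, 256): one row per frame
]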
@@ -813,8 +812,22 @@ class DecorrelateFunction(torch.autograd.Function):
         x_grad = x_grad.reshape(-1, num_channels)
         x.requires_grad = True
+        # Now, normalize the contributions of frames/pixels x to the covariance,
+        # to have magnitudes proportional to the norm of the gradient on that
+        # frame; the goal is to exclude "don't-care" frames such as padding frames from
+        # the computation.
+        x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)
+        x_sqnorm = (x.detach() ** 2).sum(dim=1)
+        x_desired_sqscale = x_grad_old_sqnorm ** 0.5  # desired scale of x in sum for cov
+        x_factor = (x_desired_sqscale / (x_sqnorm + ctx.eps)) ** 0.5
         with torch.enable_grad():
-            cov = _update_cov_stats(old_cov, x, ctx.beta)
+            scaled_x = x * x_factor.unsqueeze(-1)
+            cov = _update_cov_stats(old_cov, scaled_x, ctx.beta)
+            old_cov[:] = cov  # update the stats outside!  This is not really
+                              # how backprop is supposed to work, but this input
+                              # is not differentiable..
             loss = _compute_correlation_loss(cov, ctx.eps)
         if random.random() < 0.01:
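
[Editor's note: a self-contained sketch of the frame weighting added above, with made-up tensors and an illustrative eps (not the icefall code). Each row of scaled_x ends up with squared norm approximately equal to the norm of that row's incoming gradient, so zero-gradient padding frames contribute nothing to the covariance:

    import torch

    num_frames, num_channels = 400, 256
    x = torch.randn(num_frames, num_channels)
    x_grad = torch.randn(num_frames, num_channels)
    x_grad[200:] = 0.0          # pretend the tail is padding: no gradient there
    eps = 1.0e-05               # illustrative value

    x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)
    x_sqnorm = (x ** 2).sum(dim=1)
    x_desired_sqscale = x_grad_old_sqnorm ** 0.5
    x_factor = (x_desired_sqscale / (x_sqnorm + eps)) ** 0.5

    scaled_x = x * x_factor.unsqueeze(-1)
    print((scaled_x[200:] ** 2).sum())   # ~0: padding adds nothing to the cov
    print(torch.allclose((scaled_x[:200] ** 2).sum(dim=1),
                         x_grad_old_sqnorm[:200] ** 0.5, rtol=1.0e-03))  # True
]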
@@ -823,18 +836,14 @@ class DecorrelateFunction(torch.autograd.Function):
         decorr_x_grad = x.grad
-        # Now, normalize the magnitudes of the rows of the new grad
-        # contribution, to have magnitudes equals to ctx.scale times
-        # `loss ** 0.5` times the magnitude of the original grad.
-        decorr_x_grad_sqnorm = (decorr_x_grad ** 2).sum(dim=1)
-        x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)
         # loss.detach().clamp(min=0.0, max=1.0) is a factor that means once
         # the loss starts getting quite small (less than 1), we start using
         # smaller derivatives.
         decorr_loss_scale = ctx.scale * loss.detach().clamp(min=0.0, max=1.0)
-        scale = decorr_loss_scale * (x_grad_old_sqnorm / (decorr_x_grad_sqnorm + 1.0e-20)) ** 0.5
-        decorr_x_grad = decorr_x_grad * scale.unsqueeze(-1)
+        scale = decorr_loss_scale * ((x_grad ** 2).mean() / ((decorr_x_grad ** 2).mean() + 1.0e-20)) ** 0.5
+        decorr_x_grad = decorr_x_grad * scale
         x_grad = x_grad + decorr_x_grad
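
[Editor's note: the new `scale` is a single scalar rather than the old per-row vector. It matches the RMS of the decorrelation gradient to the RMS of the incoming gradient, then shrinks it by `decorr_loss_scale`. A toy sketch with made-up values (the loss value and ctx.scale here are hypothetical):

    import torch

    x_grad = torch.randn(400, 256)           # incoming gradient
    decorr_x_grad = torch.randn(400, 256)    # gradient of the decorrelation loss
    loss = torch.tensor(2.5)                 # hypothetical correlation loss value
    ctx_scale = 0.1                          # hypothetical ctx.scale

    decorr_loss_scale = ctx_scale * loss.clamp(min=0.0, max=1.0)   # == 0.1 here
    scale = decorr_loss_scale * ((x_grad ** 2).mean()
                                 / ((decorr_x_grad ** 2).mean() + 1.0e-20)) ** 0.5
    decorr_x_grad = decorr_x_grad * scale
    # The extra term's RMS is decorr_loss_scale times the incoming gradient's RMS:
    print((decorr_x_grad ** 2).mean() ** 0.5 / (x_grad ** 2).mean() ** 0.5)  # ~0.1
    x_grad = x_grad + decorr_x_grad
]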
@@ -902,19 +911,12 @@ class Decorrelate(torch.nn.Module):
             return x
         with torch.cuda.amp.autocast(enabled=False):
             x = x.to(torch.float32)
-            ans = DecorrelateFunction.apply(x, self.cov.clone(),
+            # the function updates self.cov in its backward pass (it needs the gradient
+            # norm, for frame weighting).
+            ans = DecorrelateFunction.apply(x, self.cov,
                                             self.scale, self.eps, self.beta,
                                             self.channel_dim)  # == x.
+            return ans
-            x = x.transpose(self.channel_dim, -1)
-            x = x.reshape(-1, x.shape[-1])
-            cov = torch.matmul(x.t(), x)
-            with torch.no_grad():
-                self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
-                m = self.cov.max()
-                assert m == m
-            return ans  # ans == x.
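
[Editor's note: the `old_cov[:] = cov` line in the backward pass is what lets forward pass `self.cov` directly instead of a clone: the buffer is mutated in place during backprop. A minimal, self-contained sketch of that pattern, keeping only the buffer-update half (the real code also backprops a correlation loss through scaled_x); the class name and identity forward are illustrative, not icefall's API:

    import torch
    from torch import Tensor
    from typing import Tuple

    class AccumInBackward(torch.autograd.Function):
        # Identity in forward; in backward, accumulates covariance stats of x,
        # weighting each row by its gradient norm (the pattern in this commit).

        @staticmethod
        def forward(ctx, x: Tensor, cov: Tensor, beta: float) -> Tensor:
            ctx.save_for_backward(x, cov)
            ctx.beta = beta
            return x

        @staticmethod
        def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None]:
            x, cov = ctx.saved_tensors
            grad_norm = (x_grad ** 2).sum(dim=1, keepdim=True) ** 0.5
            x_sqnorm = (x ** 2).sum(dim=1, keepdim=True)
            scaled_x = x * (grad_norm / (x_sqnorm + 1.0e-05)) ** 0.5
            new_cov = (ctx.beta * cov
                       + (1 - ctx.beta) * torch.matmul(scaled_x.t(), scaled_x))
            cov[:] = new_cov   # mutate the caller's buffer "from outside",
                               # as old_cov[:] = cov does in the diff
            return x_grad, None, None

    cov = torch.zeros(8, 8)
    x = torch.randn(20, 8, requires_grad=True)
    y = AccumInBackward.apply(x, cov, 0.9)
    y.sum().backward()         # only now does cov get updated
    print(cov.diag().mean())   # nonzero after backward
]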