mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-19 05:54:20 +00:00

Apply x row scaling with grad

This commit is contained in:
    parent 86c2d0fcc0
    commit cecd52155c
@@ -713,25 +713,6 @@ class GaussProjDrop(torch.nn.Module):
         x = (x_next * self.rand_scale + x_bypass)
         return x
 
 
-class PseudoNormalizeFunction(torch.autograd.Function):
-    """
-    Function object that is the identity function in the forward pass; and, in the
-    backward pass, removes the component of the derivative in the direction of x itself
-    (as if it had gone through some kind of normalization layer
-    """
-    @staticmethod
-    def forward(ctx, x: Tensor) -> Tensor:
-        ctx.save_for_backward(x)
-        return x
-
-    @staticmethod
-    def backward(ctx, x_grad: Tensor) -> Tensor:
-        x, = ctx.saved_tensors
-        eps = 1.0e-20
-        x_sumsq = (x**2).sum() + eps
-        grad_x_sum = (x_grad * x).sum()
-        return x_grad - x * (grad_x_sum / x_sumsq)
-
 def _compute_correlation_loss(cov: Tensor,
                               eps: float) -> Tensor:
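The class removed above was an identity in the forward pass whose backward subtracted the component of the incoming gradient that lies along x itself. A minimal sketch of that projection on plain tensors (hypothetical shapes, not the autograd.Function itself):

```python
import torch

# Reproduce the arithmetic of the deleted PseudoNormalizeFunction.backward:
# subtract from the gradient its component in the direction of x.
x = torch.randn(3, 4)            # activations (hypothetical shape)
g = torch.randn(3, 4)            # some upstream gradient
eps = 1.0e-20

g_proj = g - x * ((g * x).sum() / ((x ** 2).sum() + eps))

# The projected gradient is (numerically) orthogonal to x, which is what the
# removed _test_pseudo_normalize asserted.
print((g_proj * x).sum())        # close to 0
```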
@@ -759,7 +740,6 @@ def _update_cov_stats(cov: Tensor,
        x: Tensor of features/activations, of shape (num_frames, num_channels)
        beta: The decay constant for the stats, e.g. 0.8.
     """
-    x = PseudoNormalizeFunction.apply(x)
     new_cov = torch.matmul(x.t(), x)
     return cov * beta + new_cov * (1-beta)
 
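With the PseudoNormalizeFunction.apply call dropped, _update_cov_stats is a plain exponential moving average of the per-batch covariance x^T x with decay constant beta. A minimal sketch, using hypothetical shapes and the beta value from the docstring:

```python
import torch

# Running covariance stats as kept by _update_cov_stats: an exponential
# moving average of x^T x with decay constant beta.
num_frames, num_channels = 100, 8    # hypothetical shapes
beta = 0.8                           # example value from the docstring

cov = torch.zeros(num_channels, num_channels)
x = torch.randn(num_frames, num_channels)

new_cov = torch.matmul(x.t(), x)     # (num_channels, num_channels)
cov = cov * beta + new_cov * (1 - beta)
print(cov.shape)                     # torch.Size([8, 8])
```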
@@ -818,18 +798,19 @@ class DecorrelateFunction(torch.autograd.Function):
         # the computation.
 
         x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)
-        x_sqnorm = (x.detach() ** 2).sum(dim=1)
-
-        x_desired_sqscale = x_grad_old_sqnorm ** 0.5  # desired scale of x*x in sum for cov
-        x_desired_sqscale /= (x_desired_sqscale.sum() + 1.0e-20)  # sum-to-one scales
-        x_desired_sqscale_is_inf = (x_desired_sqscale - x_desired_sqscale != 0)
-        # if grads are inf, use equal scales for frames (can happen due to GradScaler, in half
-        # precision)
-        x_desired_sqscale.masked_fill_(x_desired_sqscale_is_inf, 1.0 / x_desired_sqscale.numel())
-
-        x_factor = (x_desired_sqscale * num_channels / (x_sqnorm + ctx.eps)) ** 0.5
-
         with torch.enable_grad():
+            x_sqnorm = (x ** 2).sum(dim=1)
+
+            x_desired_sqscale = x_grad_old_sqnorm ** 0.5  # desired scale of x*x in sum for cov
+            x_desired_sqscale /= (x_desired_sqscale.sum() + 1.0e-20)  # sum-to-one scales
+            x_desired_sqscale_is_inf = (x_desired_sqscale - x_desired_sqscale != 0)
+            # if grads are inf, use equal scales for frames (can happen due to GradScaler, in half
+            # precision)
+            x_desired_sqscale.masked_fill_(x_desired_sqscale_is_inf, 1.0 / x_desired_sqscale.numel())
+
+            x_factor = (x_desired_sqscale * num_channels / (x_sqnorm + ctx.eps)) ** 0.5
+
             scaled_x = x * x_factor.unsqueeze(-1)
             cov = _update_cov_stats(old_cov, scaled_x, ctx.beta)
             assert old_cov.dtype != torch.float16
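This hunk is the change the commit title refers to: the per-frame scale factors are now computed inside torch.enable_grad() from x itself rather than x.detach(), so the decorrelation loss can back-propagate through the row scaling as well. A minimal standalone sketch, with made-up shapes and an assumed value standing in for ctx.eps:

```python
import torch

# Per-frame row scaling computed with gradient tracking enabled: each row of
# x is rescaled so its weight in the covariance stats follows the magnitude of
# its incoming gradient, and x_factor itself stays differentiable w.r.t. x.
num_frames, num_channels = 4, 8      # hypothetical shapes
eps = 1.0e-4                         # stand-in for ctx.eps (assumed value)

x = torch.randn(num_frames, num_channels, requires_grad=True)
x_grad_old = torch.randn(num_frames, num_channels)   # stand-in upstream grad
x_grad_old_sqnorm = (x_grad_old ** 2).sum(dim=1)

with torch.enable_grad():
    x_sqnorm = (x ** 2).sum(dim=1)                   # uses x, not x.detach()
    x_desired_sqscale = x_grad_old_sqnorm ** 0.5
    x_desired_sqscale /= (x_desired_sqscale.sum() + 1.0e-20)
    x_factor = (x_desired_sqscale * num_channels / (x_sqnorm + eps)) ** 0.5
    scaled_x = x * x_factor.unsqueeze(-1)

print(scaled_x.requires_grad)        # True: the scaling participates in autograd
```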
@@ -1014,13 +995,6 @@ def _test_gauss_proj_drop():
     m1.eval()
     m2.eval()
 
-def _test_pseudo_normalize():
-    x = torch.randn(3, 4)
-    x.requires_grad = True
-    y = PseudoNormalizeFunction.apply(x)
-    l = (y**2).sum()
-    l.backward()
-    assert (x.grad * x).sum().abs() < 0.1
 
 def _test_decorrelate():
     D = 384
@@ -1049,7 +1023,6 @@ if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
-    _test_pseudo_normalize()
     _test_decorrelate()
     _test_gauss_proj_drop()
     _test_activation_balancer_sign()