Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)
Simplified gradient scaling [no scaling]; only use 1k first iters; beta =0.8
This commit is contained in:
parent cecd52155c
commit 9d4633facf
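In summary, the change disables the per-frame, gradient-based scaling of the features before the covariance statistics are accumulated (the scaling code is commented out rather than deleted), shortens the window in which Decorrelate is applied from the first 3000 to the first 1000 training steps, and lowers the covariance EMA coefficient beta from 0.95 to 0.8. As a rough guide (this arithmetic is editorial, not from the commit), an EMA with coefficient beta averages over about 1/(1 - beta) recent updates:

old_horizon = 1.0 / (1.0 - 0.95)  # = 20.0 updates with beta=0.95
new_horizon = 1.0 / (1.0 - 0.8)   # = 5.0 updates with beta=0.8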
@@ -796,32 +796,32 @@ class DecorrelateFunction(torch.autograd.Function):
         # to have magnitudes proportional to the norm of the gradient on that
         # frame; the goal is to exclude "don't-care" frames such as padding frames from
         # the computation.
-        x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)
+        #x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)
 
         with torch.enable_grad():
-            x_sqnorm = (x ** 2).sum(dim=1)
+            #x_sqnorm = (x ** 2).sum(dim=1)
 
-            x_desired_sqscale = x_grad_old_sqnorm ** 0.5 # desired scale of x*x in sum for cov
-            x_desired_sqscale /= (x_desired_sqscale.sum() + 1.0e-20) # sum-to-one scales
-            x_desired_sqscale_is_inf = (x_desired_sqscale - x_desired_sqscale != 0)
+            #x_desired_sqscale = x_grad_old_sqnorm ** 0.5 # desired scale of x*x in sum for cov
+            #x_desired_sqscale /= (x_desired_sqscale.sum() + 1.0e-20) # sum-to-one scales
+            #x_desired_sqscale_is_inf = (x_desired_sqscale - x_desired_sqscale != 0)
             # if grads are inf, use equal scales for frames (can happen due to GradScaler, in half
             # precision)
-            x_desired_sqscale.masked_fill_(x_desired_sqscale_is_inf, 1.0 / x_desired_sqscale.numel())
+            #x_desired_sqscale.masked_fill_(x_desired_sqscale_is_inf, 1.0 / x_desired_sqscale.numel())
 
-            x_factor = (x_desired_sqscale * num_channels / (x_sqnorm + ctx.eps)) ** 0.5
+            #x_factor = (x_desired_sqscale * num_channels / (x_sqnorm + ctx.eps)) ** 0.5
 
-            scaled_x = x * x_factor.unsqueeze(-1)
-            cov = _update_cov_stats(old_cov, scaled_x, ctx.beta)
+            #scaled_x = x * x_factor.unsqueeze(-1)
+            cov = _update_cov_stats(old_cov, x, ctx.beta)
             assert old_cov.dtype != torch.float16
             old_cov[:] = cov # update the stats outside! This is not really
                              # how backprop is supposed to work, but this input
                              # is not differentiable..
             loss = _compute_correlation_loss(cov, ctx.eps)
+            assert loss.dtype == torch.float32
             #print(f"x_sqnorm mean = {x_sqnorm.mean().item()}, x_sqnorm_mean={x_sqnorm.mean().item()}, x_desired_sqscale_sum={x_desired_sqscale.sum()}, x_grad_old_sqnorm mean = {x_grad_old_sqnorm.mean().item()}, x**2_mean = {(x**2).mean().item()}, scaled_x**2_mean = {(scaled_x**2).mean().item()}, (cov-abs-mean)={cov.abs().mean().item()}, old_cov_abs_mean={old_cov.abs().mean().item()}, loss = {loss}")
 
-            if random.random() < 0.01:
+            #if random.random() < 0.01:
+            if random.random() < 0.05:
                 logging.info(f"Decorrelate: loss = {loss}")
 
             loss.backward()
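The helpers _update_cov_stats and _compute_correlation_loss are defined elsewhere in the file and do not appear in this diff. A minimal sketch of what such a pair could compute, assuming the stats are a plain EMA of the batch covariance and the loss is the mean squared off-diagonal correlation (function names, shapes, and formulas here are assumptions, not the actual icefall implementation):

import torch

def ema_cov_update(old_cov: torch.Tensor, x: torch.Tensor, beta: float) -> torch.Tensor:
    # Assumed stand-in for _update_cov_stats: EMA of the batch covariance,
    # with x of shape (num_frames, num_channels).
    batch_cov = torch.matmul(x.t(), x) / x.shape[0]
    return beta * old_cov + (1.0 - beta) * batch_cov

def off_diag_corr_loss(cov: torch.Tensor, eps: float) -> torch.Tensor:
    # Assumed stand-in for _compute_correlation_loss: normalize the covariance
    # to a correlation matrix, then average the squared off-diagonal entries.
    inv_sqrt_diag = (cov.diag() + eps) ** -0.5
    corr = cov * inv_sqrt_diag.unsqueeze(0) * inv_sqrt_diag.unsqueeze(1)
    n = cov.shape[0]
    off_diag_sq = (corr ** 2).sum() - (corr.diag() ** 2).sum()
    return off_diag_sq / (n * (n - 1))

Under this reading, passing x instead of scaled_x to _update_cov_stats means every frame contributes to cov in proportion to its raw energy, rather than being reweighted by its gradient norm.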
@@ -862,9 +862,9 @@ class Decorrelate(torch.nn.Module):
     def __init__(self,
                  num_channels: int,
                  scale: float = 0.1,
-                 apply_steps: int = 3000,
+                 apply_steps: int = 1000,
                  eps: float = 1.0e-05,
-                 beta: float = 0.95,
+                 beta: float = 0.8,
                  channel_dim: int = -1):
         super(Decorrelate, self).__init__()
         self.scale = scale
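For reference, a hypothetical construction using the new defaults; the channel count, tensor shapes, and call pattern below are illustrative, with only the constructor signature taken from the diff:

import torch

decorrelate = Decorrelate(num_channels=384)  # apply_steps=1000, beta=0.8 by default
x = torch.randn(8, 100, 384)  # (batch, time, channels); channel_dim=-1 is the default
y = decorrelate(x)            # assumed to return a tensor of the same shape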