diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index ca8617e22..a3a88d83a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -796,32 +796,32 @@ class DecorrelateFunction(torch.autograd.Function):
         # to have magnitudes proportional to the norm of the gradient on that
         # frame; the goal is to exclude "don't-care" frames such as padding frames from
         # the computation.
-
-        x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)
+        #x_grad_old_sqnorm = (x_grad ** 2).sum(dim=1)
         with torch.enable_grad():
-            x_sqnorm = (x ** 2).sum(dim=1)
+            #x_sqnorm = (x ** 2).sum(dim=1)
 
-            x_desired_sqscale = x_grad_old_sqnorm ** 0.5  # desired scale of x*x in sum for cov
-            x_desired_sqscale /= (x_desired_sqscale.sum() + 1.0e-20)  # sum-to-one scales
-            x_desired_sqscale_is_inf = (x_desired_sqscale - x_desired_sqscale != 0)
+            #x_desired_sqscale = x_grad_old_sqnorm ** 0.5  # desired scale of x*x in sum for cov
+            #x_desired_sqscale /= (x_desired_sqscale.sum() + 1.0e-20)  # sum-to-one scales
+            #x_desired_sqscale_is_inf = (x_desired_sqscale - x_desired_sqscale != 0)
             # if grads are inf, use equal scales for frames (can happen due to GradScaler, in half
             # precision)
-            x_desired_sqscale.masked_fill_(x_desired_sqscale_is_inf, 1.0 / x_desired_sqscale.numel())
+            #x_desired_sqscale.masked_fill_(x_desired_sqscale_is_inf, 1.0 / x_desired_sqscale.numel())
 
-            x_factor = (x_desired_sqscale * num_channels / (x_sqnorm + ctx.eps)) ** 0.5
+            #x_factor = (x_desired_sqscale * num_channels / (x_sqnorm + ctx.eps)) ** 0.5
 
-            scaled_x = x * x_factor.unsqueeze(-1)
-            cov = _update_cov_stats(old_cov, scaled_x, ctx.beta)
+            #scaled_x = x * x_factor.unsqueeze(-1)
+            cov = _update_cov_stats(old_cov, x, ctx.beta)
             assert old_cov.dtype != torch.float16
             old_cov[:] = cov  # update the stats outside!  This is not really
                               # how backprop is supposed to work, but this input
                               # is not differentiable..
             loss = _compute_correlation_loss(cov, ctx.eps)
-
+            assert loss.dtype == torch.float32
             #print(f"x_sqnorm mean = {x_sqnorm.mean().item()}, x_sqnorm_mean={x_sqnorm.mean().item()}, x_desired_sqscale_sum={x_desired_sqscale.sum()}, x_grad_old_sqnorm mean = {x_grad_old_sqnorm.mean().item()}, x**2_mean = {(x**2).mean().item()}, scaled_x**2_mean = {(scaled_x**2).mean().item()}, (cov-abs-mean)={cov.abs().mean().item()}, old_cov_abs_mean={old_cov.abs().mean().item()}, loss = {loss}")
-            if random.random() < 0.01:
+            #if random.random() < 0.01:
+            if random.random() < 0.05:
                 logging.info(f"Decorrelate: loss = {loss}")
 
             loss.backward()
@@ -862,9 +862,9 @@ class Decorrelate(torch.nn.Module):
     def __init__(self,
                  num_channels: int,
                  scale: float = 0.1,
-                 apply_steps: int = 3000,
+                 apply_steps: int = 1000,
                  eps: float = 1.0e-05,
-                 beta: float = 0.95,
+                 beta: float = 0.8,
                  channel_dim: int = -1):
         super(Decorrelate, self).__init__()
         self.scale = scale
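
Net effect of this change: the per-frame scaling of x (by factors derived from the gradient norms, meant to down-weight padding frames) is commented out, so the covariance stats are now accumulated from the raw x. The stats decay faster (beta 0.95 -> 0.8), decorrelation is applied for fewer steps (apply_steps 3000 -> 1000), and the loss is logged about five times as often (0.01 -> 0.05). Below is a minimal sketch of the simplified path the new code takes; the real _update_cov_stats and _compute_correlation_loss are defined elsewhere in scaling.py, and the EMA form and off-diagonal penalty here are assumptions about their behavior, not verbatim copies.

# Hedged sketch of the simplified stats update and decorrelation loss.
# Helper bodies are assumed, not taken from scaling.py.
import torch


def update_cov_stats(old_cov: torch.Tensor, x: torch.Tensor, beta: float) -> torch.Tensor:
    # Decayed average of covariance stats.  With the frame-weighting code
    # commented out, x enters the stats unscaled.
    # x: (num_frames, num_channels); old_cov: (num_channels, num_channels).
    return beta * old_cov + (1.0 - beta) * torch.matmul(x.t(), x)


def correlation_loss(cov: torch.Tensor, eps: float) -> torch.Tensor:
    # Normalize cov to a correlation matrix and penalize off-diagonal mass.
    inv_scale = (cov.diag() + eps) ** -0.5
    corr = cov * inv_scale.unsqueeze(0) * inv_scale.unsqueeze(1)
    num_channels = cov.shape[0]
    # The diagonal of corr is ~1, so subtracting num_channels leaves
    # (approximately) the sum of squared off-diagonal correlations.
    return ((corr ** 2).sum() - num_channels) / num_channels

One way to read the beta change: with beta = 0.8 the stats have an effective memory of roughly 1 / (1 - beta) = 5 recent batches instead of 20, so the loss tracks the current model state more closely, which fits the shortened apply_steps window.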