From 5fb64a59b854e042f4e792a38bfb443d150c84f1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 10 Jun 2022 19:05:04 +0800 Subject: [PATCH] Change beta from 0.8 to 0.95 --- egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index c7c309e21..2d1b83e33 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -812,9 +812,7 @@ class DecorrelateFunction(torch.autograd.Function): # is not differentiable.. loss = _compute_correlation_loss(cov, ctx.eps) assert loss.dtype == torch.float32 - #print(f"x_sqnorm mean = {x_sqnorm.mean().item()}, x_sqnorm_mean={x_sqnorm.mean().item()}, x_desired_sqscale_sum={x_desired_sqscale.sum()}, x_grad_old_sqnorm mean = {x_grad_old_sqnorm.mean().item()}, x**2_mean = {(x**2).mean().item()}, scaled_x**2_mean = {(scaled_x**2).mean().item()}, (cov-abs-mean)={cov.abs().mean().item()}, old_cov_abs_mean={old_cov.abs().mean().item()}, loss = {loss}") - #if random.random() < 0.01: if random.random() < 0.05: logging.info(f"Decorrelate: loss = {loss}") @@ -858,7 +856,7 @@ class Decorrelate(torch.nn.Module): scale: float = 0.1, apply_steps: int = 1000, eps: float = 1.0e-05, - beta: float = 0.8, + beta: float = 0.95, channel_dim: int = -1): super(Decorrelate, self).__init__() self.scale = scale