mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
Change beta from 0.8 to 0.95
This commit is contained in:
parent
6ed181595b
commit
5fb64a59b8
@ -812,9 +812,7 @@ class DecorrelateFunction(torch.autograd.Function):
|
|||||||
# is not differentiable..
|
# is not differentiable..
|
||||||
loss = _compute_correlation_loss(cov, ctx.eps)
|
loss = _compute_correlation_loss(cov, ctx.eps)
|
||||||
assert loss.dtype == torch.float32
|
assert loss.dtype == torch.float32
|
||||||
#print(f"x_sqnorm mean = {x_sqnorm.mean().item()}, x_sqnorm_mean={x_sqnorm.mean().item()}, x_desired_sqscale_sum={x_desired_sqscale.sum()}, x_grad_old_sqnorm mean = {x_grad_old_sqnorm.mean().item()}, x**2_mean = {(x**2).mean().item()}, scaled_x**2_mean = {(scaled_x**2).mean().item()}, (cov-abs-mean)={cov.abs().mean().item()}, old_cov_abs_mean={old_cov.abs().mean().item()}, loss = {loss}")
|
|
||||||
|
|
||||||
#if random.random() < 0.01:
|
|
||||||
if random.random() < 0.05:
|
if random.random() < 0.05:
|
||||||
logging.info(f"Decorrelate: loss = {loss}")
|
logging.info(f"Decorrelate: loss = {loss}")
|
||||||
|
|
||||||
@ -858,7 +856,7 @@ class Decorrelate(torch.nn.Module):
|
|||||||
scale: float = 0.1,
|
scale: float = 0.1,
|
||||||
apply_steps: int = 1000,
|
apply_steps: int = 1000,
|
||||||
eps: float = 1.0e-05,
|
eps: float = 1.0e-05,
|
||||||
beta: float = 0.8,
|
beta: float = 0.95,
|
||||||
channel_dim: int = -1):
|
channel_dim: int = -1):
|
||||||
super(Decorrelate, self).__init__()
|
super(Decorrelate, self).__init__()
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
Loading…
x
Reference in New Issue
Block a user