mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
Improve speed by messing with configuration, removing asserts.
This commit is contained in:
parent
a41c4b6c9b
commit
d1e96afce2
@ -265,8 +265,8 @@ class NeutralGradient(Optimizer):
|
||||
scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
|
||||
grad_sq = (grad**2).mean()
|
||||
conditioned_grad_sq = (conditioned_grad**2).mean()
|
||||
assert grad_sq - grad_sq == 0
|
||||
assert conditioned_grad_sq - conditioned_grad_sq == 0
|
||||
#assert grad_sq - grad_sq == 0
|
||||
#assert conditioned_grad_sq - conditioned_grad_sq == 0
|
||||
scalar_exp_avg_sq.mul_(beta2).add_(grad_sq, alpha=(1-beta2))
|
||||
bias_correction2 = 1 - beta2 ** (step + 1)
|
||||
avg_grad_sq = scalar_exp_avg_sq / bias_correction2
|
||||
@ -285,7 +285,7 @@ class NeutralGradient(Optimizer):
|
||||
delta.add_(this_delta, alpha=alpha)
|
||||
if random.random() < 0.005:
|
||||
print(f"Delta rms = {(delta**2).mean().item()}, shape = {delta.shape}")
|
||||
assert delta.abs().max() < 10.0
|
||||
#assert delta.abs().max() < 10.0
|
||||
p.add_(delta)
|
||||
if step % 10 == 0:
|
||||
p.clamp_(min=-param_max, max=param_max)
|
||||
|
Loading…
x
Reference in New Issue
Block a user