mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
Bug fix RE eps*eps; add/tune diagnostics
This commit is contained in:
parent
2615e48779
commit
932cedce59
@ -280,15 +280,16 @@ class NeutralGradient(Optimizer):
|
|||||||
if random.random() < 0.02:
|
if random.random() < 0.02:
|
||||||
print(f"grad_scale mean = {grad_scale.mean()}, shape = {p.shape}")
|
print(f"grad_scale mean = {grad_scale.mean()}, shape = {p.shape}")
|
||||||
|
|
||||||
cur_grad = grad * grad_scale
|
cur_grad = grad
|
||||||
|
cur_grad = cur_grad * grad_scale
|
||||||
cur_grad = self._precondition_grad(cur_grad, state)
|
cur_grad = self._precondition_grad(cur_grad, state)
|
||||||
cur_grad *= grad_scale
|
cur_grad *= grad_scale
|
||||||
|
|
||||||
if True: # testing
|
if random.random() < 0.004:
|
||||||
# in principle, the cur_grad is supposed to have the same rms as params, on average.
|
# in principle, the cur_grad is supposed to have the same rms as params, on average.
|
||||||
cur_grad_rms = (cur_grad**2).mean().sqrt()
|
cur_grad_rms = (cur_grad**2).mean().sqrt()
|
||||||
param_rms = (p**2).mean().sqrt()
|
param_rms = (p**2).mean().sqrt()
|
||||||
#print(f"cur_grad_rms={cur_grad_rms}, param_rms={param_rms}")
|
print(f"cur_grad_rms={cur_grad_rms}, param_rms={param_rms}")
|
||||||
|
|
||||||
if random.random() < 0.1:
|
if random.random() < 0.1:
|
||||||
prod = (grad*cur_grad).mean()
|
prod = (grad*cur_grad).mean()
|
||||||
@ -436,7 +437,7 @@ class NeutralGradient(Optimizer):
|
|||||||
proj[:] = self._estimate_proj(grad_cov_smoothed,
|
proj[:] = self._estimate_proj(grad_cov_smoothed,
|
||||||
param_cov_smoothed)
|
param_cov_smoothed)
|
||||||
|
|
||||||
state["ref_exp_avg_sq"][:] = ref_exp_avg_sq
|
state["ref_exp_avg_sq"][:] = ref_exp_avg_sq + eps*eps
|
||||||
|
|
||||||
def _get_this_beta3(self, beta3: float, numel: int, size: int):
|
def _get_this_beta3(self, beta3: float, numel: int, size: int):
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user