commit 0679b363b0
parent 932cedce59

Some debug stuff, modify cond_eps for param_cov
@@ -261,12 +261,16 @@ class NeutralGradient(Optimizer):
                     if step % stats_period == 0:
                         self._accumulate_per_dim_stats(grad, state, beta3, eps)

-                    if step % estimate_period == 0 or step in [50, 200, 400]:
+                    if step % estimate_period == 0 or step in [25, 50, 200, 400]:
                         self._estimate(p, state, beta3, max_size,
                                        stats_period, estimate_period,
                                        eps, param_eps,
                                        cond_eps, min_diag_smooth)

+                        # TEMP!! Override the setting inside _estimate.
+                        #state["ref_exp_avg_sq"][:] = ((exp_avg_sq/bias_correction2 + eps*eps) *
+                        #                              state["ref_exp_avg_sq"]).sqrt()
+
                     ref_exp_avg_sq = state["ref_exp_avg_sq"]  # computed in self._estimate()

                     # do ** 0.25, not ** 0.5, because we divide this into two factors:
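Note: the change in the hunk above only adds step 25 to the list of early steps at which _estimate() is re-run, on top of the regular estimate_period cadence, presumably because the accumulated statistics are still based on very few steps at that point. A minimal sketch of that trigger, using only the names from the diff (step, estimate_period, the early-step list); everything else is illustrative:

# Sketch: decide whether the preconditioner should be re-estimated this step.
# The early steps force extra re-estimations while the accumulated statistics
# are still based on very few samples.
EARLY_ESTIMATE_STEPS = [25, 50, 200, 400]

def should_estimate(step: int, estimate_period: int) -> bool:
    return step % estimate_period == 0 or step in EARLY_ESTIMATE_STEPS

assert should_estimate(25, 2000)       # newly added early step
assert should_estimate(4000, 2000)     # regular cadence
assert not should_estimate(26, 2000)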
@@ -278,7 +282,7 @@ class NeutralGradient(Optimizer):
                     grad_scale = (ref_exp_avg_sq / (exp_avg_sq/bias_correction2 + eps*eps)) ** 0.25

                     if random.random() < 0.02:
-                        print(f"grad_scale mean = {grad_scale.mean()}, shape = {p.shape}")
+                        print(f"grad_scale mean = {grad_scale.mean().item():.43}, shape = {p.shape}")

                     cur_grad = grad
                     cur_grad = cur_grad * grad_scale
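Note: grad_scale rescales the gradient so that its second moment matches ref_exp_avg_sq; per the comment in the diff, the exponent is 0.25 rather than 0.5 because the correction is applied as two equal factors. A rough sketch with stand-in tensors (eps, bias_correction2 and the shapes are made up, not taken from the optimizer state):

import torch

eps = 1.0e-08
bias_correction2 = 0.9
exp_avg_sq = torch.rand(100) + 0.1      # running average of grad**2 (illustrative)
ref_exp_avg_sq = torch.rand(100) + 0.1  # target scale computed by _estimate() (illustrative)
grad = torch.randn(100)

# ** 0.25 rather than ** 0.5: the correction is split into two equal factors.
grad_scale = (ref_exp_avg_sq / (exp_avg_sq / bias_correction2 + eps * eps)) ** 0.25
cur_grad = grad * grad_scale            # one of the two factors applied here
# Applying grad_scale twice would scale grad**2 by the full ratio
# ref_exp_avg_sq / (exp_avg_sq / bias_correction2 + eps * eps).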
@@ -288,8 +292,12 @@ class NeutralGradient(Optimizer):
                     if random.random() < 0.004:
                         # in principle, the cur_grad is supposed to have the same rms as params, on average.
                         cur_grad_rms = (cur_grad**2).mean().sqrt()
+                        # _corrected corrects for the overall size of the grad, making cur_grad_rms more similar
+                        # to the average, so we can compare with param_rms.
+                        cur_grad_rms_corrected = cur_grad_rms * ((exp_avg_sq/bias_correction2).mean().sqrt() /
+                                                                 (grad**2).mean().sqrt())
                         param_rms = (p**2).mean().sqrt()
-                        print(f"cur_grad_rms={cur_grad_rms}, param_rms={param_rms}")
+                        print(f"cur_grad_rms={cur_grad_rms.item():.3e}, corrected_grad_rms={cur_grad_rms_corrected.item():.3e}, param_rms={param_rms.item():.3e}")

                     if random.random() < 0.1:
                         prod = (grad*cur_grad).mean()
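Note: the added diagnostic rescales the current step's gradient RMS by the ratio of the running-average gradient RMS to this step's raw gradient RMS, so the printed value is comparable across steps and can be checked against param_rms. A standalone sketch, with illustrative tensors standing in for the optimizer state (p, grad, cur_grad, exp_avg_sq, bias_correction2):

import torch

p = torch.randn(50, 50)
grad = torch.randn(50, 50) * 0.1
cur_grad = grad * 2.0                   # pretend preconditioned gradient
exp_avg_sq = (grad ** 2) * 0.9          # running average of grad**2 (illustrative)
bias_correction2 = 0.9

cur_grad_rms = (cur_grad ** 2).mean().sqrt()
# Multiply by (average grad rms) / (this step's grad rms) so the value is
# comparable across steps and can be checked against param_rms.
cur_grad_rms_corrected = cur_grad_rms * ((exp_avg_sq / bias_correction2).mean().sqrt()
                                         / (grad ** 2).mean().sqrt())
param_rms = (p ** 2).mean().sqrt()
print(f"cur_grad_rms={cur_grad_rms.item():.3e}, "
      f"corrected={cur_grad_rms_corrected.item():.3e}, "
      f"param_rms={param_rms.item():.3e}")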
@@ -415,7 +423,7 @@ class NeutralGradient(Optimizer):
             else:
                 param_cov_smoothed = self._estimate_and_smooth_param_cov(p, dim,
                                                                          param_eps,
-                                                                         cond_eps=cond_eps,
+                                                                         cond_eps=1.0e-04,
                                                                          min_diag_smooth=min_diag_smooth)
             grad_cov_smoothed = self._smooth_grad_cov(p, grad_cov,
                                                       eps, norm_step,
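Note: this is the cond_eps change from the commit message: the configured cond_eps is overridden with a hard-coded 1.0e-04 when smoothing the parameter covariance. The formula inside _estimate_and_smooth_param_cov() is not shown in this diff; the sketch below only illustrates the usual role of a condition-number epsilon, i.e. mixing a small multiple of the diagonal back into the covariance so it cannot become arbitrarily ill-conditioned. The function name and formula here are assumptions, not the actual implementation:

import torch

def smooth_cov(cov: torch.Tensor, cond_eps: float = 1.0e-04) -> torch.Tensor:
    # Hypothetical smoothing: add cond_eps times the mean diagonal as a
    # multiple of the identity, bounding the condition number from above.
    size = cov.shape[0]
    return cov + cond_eps * cov.diag().mean() * torch.eye(size, dtype=cov.dtype)

x = torch.randn(3, 4)
cov = (x.t() @ x) / x.shape[0]             # rank-deficient 4x4 covariance
print(torch.linalg.cond(cov))              # huge / inf
print(torch.linalg.cond(smooth_cov(cov)))  # bounded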
@@ -450,7 +458,11 @@ class NeutralGradient(Optimizer):
         rank_per_iter = numel // size  # maximum rank of each iteration's covariance
         safety_factor = 4.0  # should be > 1.0
         grad_num_steps_needed = safety_factor * size / rank_per_iter
-        return max(beta3, 1 - 1. / grad_num_steps_needed)
+        ans = max(beta3, 1 - 1. / grad_num_steps_needed)
+        if ans != beta3 and random.random() > 0.1:
+            print(f"get_this_beta3: ans={ans}, beta3={beta3}, numel={numel}, size={size}")
+        return ans


     def _precondition_grad(self,
                            grad: Tensor,
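Note: the hunk above only adds a debug print to the get_this_beta3 logic; the computation is unchanged. Each step's gradient contributes a covariance of rank at most numel // size, so roughly safety_factor * size / rank_per_iter steps are needed for a usable estimate, and beta3 is floored at 1 - 1/steps_needed so the averaging window is at least that long. Restated as a standalone function (the body is taken from the diff; only the free-function signature is assumed):

import random

def get_this_beta3(beta3: float, numel: int, size: int) -> float:
    rank_per_iter = numel // size        # maximum rank of each iteration's covariance
    safety_factor = 4.0                  # should be > 1.0
    grad_num_steps_needed = safety_factor * size / rank_per_iter
    ans = max(beta3, 1 - 1. / grad_num_steps_needed)
    if ans != beta3 and random.random() > 0.1:
        print(f"get_this_beta3: ans={ans}, beta3={beta3}, numel={numel}, size={size}")
    return ans

# A square (512, 512) weight: rank_per_iter == size, so only 4 steps are
# "needed" and the configured beta3 is returned unchanged.
print(get_this_beta3(0.98, 512 * 512, 512))
# A size-512 vector: rank_per_iter == 1, so 2048 steps are needed and the
# returned decay is floored at 1 - 1/2048.
print(get_this_beta3(0.98, 512, 512))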
@@ -654,6 +666,11 @@ class NeutralGradient(Optimizer):
         P = torch.matmul(Y, Y.t())

+        if random.random() < 0.1:
+            # TEMP:
+            _,s,_ = P.svd()
+            print(f"Min,max eig of P: {s.min()},{s.max()}")
+
         # TODO: remove this testing code.
         assert (P - P.t()).abs().mean() < 0.01  # make sure symmetric.
         # testing... note, this is only true modulo "eps"
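Note: P = Y @ Y.t() is symmetric positive semi-definite by construction, so its singular values equal its eigenvalues; the added debug print uses svd() to report the extreme eigenvalues and monitor conditioning, while the existing assert checks symmetry. A self-contained sketch with a random stand-in for Y (in the optimizer, Y comes from the preconditioner computation just above this hunk):

import torch

Y = torch.randn(8, 8)
P = torch.matmul(Y, Y.t())

assert (P - P.t()).abs().mean() < 0.01  # symmetric by construction
_, s, _ = P.svd()                       # for symmetric PSD P, singular values == eigenvalues
print(f"Min,max eig of P: {s.min()},{s.max()}")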
@@ -1315,7 +1332,7 @@ def _test_eve_cain():
         if iter == 0: optim = Eve(m.parameters(), lr=0.003)
         elif iter == 1: optim = Cain(m.parameters(), lr=0.03)
         elif iter == 2: optim = NeutralGradient(m.parameters(), lr=0.03, max_size=10)
-        elif iter == 3: optim = NeutralGradient(m.parameters(), lr=0.05, max_size=1000)
+        elif iter == 3: optim = NeutralGradient(m.parameters(), lr=0.03, max_size=1000)
         scheduler = Eden(optim, lr_batches=200, lr_epochs=10, verbose=False)

         start = timeit.default_timer()
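Note: the only change here is the iter == 3 learning rate (0.05 -> 0.03) in the optimizer sweep of _test_eve_cain(). A sketch of the sweep after the change; the constructors are copied from the diff, but the import line is an assumption about where these classes live (the optim.py edited by this commit):

from optim import Eve, Cain, NeutralGradient, Eden  # hypothetical import path

def make_optimizer(iter: int, m):
    if iter == 0:
        return Eve(m.parameters(), lr=0.003)
    elif iter == 1:
        return Cain(m.parameters(), lr=0.03)
    elif iter == 2:
        return NeutralGradient(m.parameters(), lr=0.03, max_size=10)
    elif iter == 3:
        return NeutralGradient(m.parameters(), lr=0.03, max_size=1000)

# optim = make_optimizer(iter, m)
# scheduler = Eden(optim, lr_batches=200, lr_epochs=10, verbose=False)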