Make it less verbose; fix scale_speed setting; testing min_diag_smooth=1.0 for debug

This commit is contained in:
Daniel Povey 2022-06-15 13:02:04 +08:00
parent bc5d72b0f3
commit 860322bf30

View File

@ -66,12 +66,12 @@ class NeutralGradient(Optimizer):
params, params,
lr=1e-2, lr=1e-2,
betas=(0.9, 0.98, 0.98), betas=(0.9, 0.98, 0.98),
scale_speed=0.05, scale_speed=0.1,
eps=1e-8, eps=1e-8,
param_eps=1.0e-05, param_eps=1.0e-05,
cond_eps=1.0e-10, cond_eps=1.0e-10,
param_max=10.0, param_max=10.0,
min_diag_smooth=0.5, min_diag_smooth=1.0,
max_size=1023, max_size=1023,
stats_period=1, stats_period=1,
estimate_period=200, estimate_period=200,
@ -109,7 +109,7 @@ class NeutralGradient(Optimizer):
lr=lr, lr=lr,
betas=betas, betas=betas,
eps=eps, eps=eps,
scale_speed=0.05, scale_speed=scale_speed,
param_eps=param_eps, param_eps=param_eps,
cond_eps=cond_eps, cond_eps=cond_eps,
min_diag_smooth=min_diag_smooth, min_diag_smooth=min_diag_smooth,
@ -237,6 +237,8 @@ class NeutralGradient(Optimizer):
# p = underlying_param * scale.exp(), # p = underlying_param * scale.exp(),
# delta is the change in `scale`. # delta is the change in `scale`.
scale_delta = scale_alpha * (scale_deriv / scale_denom) scale_delta = scale_alpha * (scale_deriv / scale_denom)
#if random.random() < 0.01:
# print("scale_delta = ", scale_delta)
delta.add_(p, alpha=scale_delta) delta.add_(p, alpha=scale_delta)
@ -281,8 +283,8 @@ class NeutralGradient(Optimizer):
grad_scale = (ref_exp_avg_sq / (exp_avg_sq/bias_correction2 + eps*eps)) ** 0.25 grad_scale = (ref_exp_avg_sq / (exp_avg_sq/bias_correction2 + eps*eps)) ** 0.25
if random.random() < 0.02: #if random.random() < 0.02:
print(f"grad_scale mean = {grad_scale.mean().item():.43}, shape = {p.shape}") # print(f"grad_scale mean = {grad_scale.mean().item():.43}, shape = {p.shape}")
cur_grad = grad cur_grad = grad
cur_grad = cur_grad * grad_scale cur_grad = cur_grad * grad_scale
@ -1377,7 +1379,7 @@ def _test_eve_cain():
optim.zero_grad() optim.zero_grad()
scheduler.step_batch() scheduler.step_batch()
diagnostic.print_diagnostics() #diagnostic.print_diagnostics()
stop = timeit.default_timer() stop = timeit.default_timer()
print(f"Iter={iter}, Time taken: {stop - start}") print(f"Iter={iter}, Time taken: {stop - start}")