Add diagnostics

Daniel Povey 2022-06-15 12:39:16 +08:00
parent 0679b363b0
commit 57957cc049


@@ -21,7 +21,7 @@ import torch
 import random
 from torch import Tensor
 from torch.optim import Optimizer
+from icefall import diagnostics # only for testing code

 class NeutralGradient(Optimizer):
     """
@ -65,8 +65,8 @@ class NeutralGradient(Optimizer):
self, self,
params, params,
lr=1e-2, lr=1e-2,
betas=(0.9, 0.98, 0.99), betas=(0.9, 0.98, 0.98),
scale_speed=0.05, scale_speed=0.025,
eps=1e-8, eps=1e-8,
param_eps=1.0e-05, param_eps=1.0e-05,
cond_eps=1.0e-10, cond_eps=1.0e-10,
@@ -261,7 +261,7 @@ class NeutralGradient(Optimizer):
                 if step % stats_period == 0:
                     self._accumulate_per_dim_stats(grad, state, beta3, eps)
                 if step % estimate_period == 0 or step in [25, 50, 200, 400]:
                     self._estimate(p, state, beta3, max_size,
                                    stats_period, estimate_period,
                                    eps, param_eps,
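A note on the cadence above: per-dim statistics are accumulated cheaply every
stats_period steps, while the expensive preconditioner re-estimate runs every
estimate_period steps, with a few extra early refreshes at steps 25, 50, 200
and 400 before the first full period elapses. A minimal standalone sketch of
that scheduling (the period values here are illustrative, not the optimizer's
defaults):

stats_period, estimate_period = 10, 200
for step in range(1, 401):
    if step % stats_period == 0:
        pass  # cheap: fold this step's gradient into the running stats
    if step % estimate_period == 0 or step in [25, 50, 200, 400]:
        pass  # expensive: re-estimate the preconditioner from those stats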
@@ -526,7 +526,10 @@ class NeutralGradient(Optimizer):
         param_cov = torch.matmul(p.t(), p) / p.shape[0]

         # later we may be able to find a more principled formula for this.
-        diag_smooth = max(min_diag_smooth, size / (size + num_outer_products))
+        #if random.random() < 0.2:
+        #    print(f"param diag_smooth = {diag_smooth}, shape={p.shape}")
+        #diag_smooth = min_diag_smooth
+        diag_smooth = 0.4
         diag = param_cov.diag()
         extra_diag = (diag * diag_smooth) + (diag.max() * cond_eps +
                                              param_eps * param_eps)
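This hunk hard-codes diag_smooth = 0.4 in place of the size-dependent formula.
The smoothing itself works as in the sketch below; shapes are illustrative,
and the final blend line mirrors the grad_cov hunk that follows (the
corresponding param_cov line sits outside this hunk, so treating it the same
way is an assumption):

import torch

diag_smooth = 0.4      # hard-coded by this commit (was size-dependent)
cond_eps = 1.0e-10
param_eps = 1.0e-05

p = torch.randn(100, 20)
param_cov = torch.matmul(p.t(), p) / p.shape[0]

diag = param_cov.diag()
extra_diag = (diag * diag_smooth) + (diag.max() * cond_eps +
                                     param_eps * param_eps)
# Shrink all entries by (1 - diag_smooth), then add the kept fraction of the
# diagonal back (plus small epsilons): the diagonal is roughly preserved
# while off-diagonal mass is damped, improving conditioning.
param_cov = param_cov * (1 - diag_smooth) + extra_diag.diag()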
@@ -582,10 +585,13 @@ class NeutralGradient(Optimizer):
         num_outer_products = rank_per_iter * num_iters_in_stats
         diag_smooth = max(min_diag_smooth,
                           size / (size + num_outer_products))
+        if random.random() < 0.5:
+            print(f"grad diag_smooth = {diag_smooth}, shape={p.shape}")
+
         diag = grad_cov.diag()
         extra_diag = (diag * diag_smooth) + (diag.max() * cond_eps +
                                              eps * eps)
-        grad_cov.mul_(1-diag_smooth).add_(extra_diag.diag())
+        grad_cov = (grad_cov * (1-diag_smooth)).add_(extra_diag.diag())
         return grad_cov

     def _get_cov(self, x: Tensor, dim: int) -> Tensor:
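The last change in this hunk swaps an in-place mul_() for an out-of-place
multiply. The practical difference is aliasing: if grad_cov shares storage
with a tensor held in the optimizer state, mul_() would silently corrupt the
stored statistics, while the out-of-place form leaves them intact. The
state-aliasing motive is my reading of the change, not stated in the commit,
but the PyTorch semantics it relies on is standard:

import torch

state_cov = torch.eye(3)   # stands in for statistics kept across steps
grad_cov = state_cov       # an alias: same storage, not a copy

# Out-of-place (the new code): builds a fresh tensor; the alias is untouched.
grad_cov = (grad_cov * 0.6).add_(torch.eye(3) * 0.4)
assert torch.equal(state_cov, torch.eye(3))

# In-place (the old code): mutates the stored tensor through the alias.
grad_cov2 = state_cov
grad_cov2.mul_(0.6).add_(torch.eye(3) * 0.4)
assert not torch.equal(state_cov, torch.eye(3))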
@@ -1338,6 +1344,14 @@ def _test_eve_cain():
     start = timeit.default_timer()
     for epoch in range(150):
         scheduler.step_epoch()
+
+        if epoch == 130:
+            opts = diagnostics.TensorDiagnosticOptions(
+                2 ** 22
+            )  # allow 4 megabytes per sub-module
+            diagnostic = diagnostics.attach_diagnostics(m, opts)
+
         for n, (x,y) in enumerate(train_pairs):
             y_out = m(x)
             loss = ((y_out - y)**2).mean() * 100.0
@@ -1356,6 +1370,8 @@ def _test_eve_cain():
             optim.zero_grad()
             scheduler.step_batch()

+    diagnostic.print_diagnostics()
+
     stop = timeit.default_timer()
     print(f"Iter={iter}, Time taken: {stop - start}")