mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
Add diagnostics
This commit is contained in:
parent
0679b363b0
commit
57957cc049
@ -21,7 +21,7 @@ import torch
|
||||
import random
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from icefall import diagnostics # only for testing code
|
||||
|
||||
class NeutralGradient(Optimizer):
|
||||
"""
|
||||
@ -65,8 +65,8 @@ class NeutralGradient(Optimizer):
|
||||
self,
|
||||
params,
|
||||
lr=1e-2,
|
||||
betas=(0.9, 0.98, 0.99),
|
||||
scale_speed=0.05,
|
||||
betas=(0.9, 0.98, 0.98),
|
||||
scale_speed=0.025,
|
||||
eps=1e-8,
|
||||
param_eps=1.0e-05,
|
||||
cond_eps=1.0e-10,
|
||||
@ -526,7 +526,10 @@ class NeutralGradient(Optimizer):
|
||||
param_cov = torch.matmul(p.t(), p) / p.shape[0]
|
||||
|
||||
# later we may be able to find a more principled formula for this.
|
||||
diag_smooth = max(min_diag_smooth, size / (size + num_outer_products))
|
||||
#if random.random() < 0.2:
|
||||
# print(f"param diag_smooth = {diag_smooth}, shape={p.shape}")
|
||||
#diag_smooth = min_diag_smooth
|
||||
diag_smooth = 0.4
|
||||
diag = param_cov.diag()
|
||||
extra_diag = (diag * diag_smooth) + (diag.max() * cond_eps +
|
||||
param_eps * param_eps)
|
||||
@ -582,10 +585,13 @@ class NeutralGradient(Optimizer):
|
||||
num_outer_products = rank_per_iter * num_iters_in_stats
|
||||
diag_smooth = max(min_diag_smooth,
|
||||
size / (size + num_outer_products))
|
||||
if random.random() < 0.5:
|
||||
print(f"grad diag_smooth = {diag_smooth}, shape={p.shape}")
|
||||
|
||||
diag = grad_cov.diag()
|
||||
extra_diag = (diag * diag_smooth) + (diag.max() * cond_eps +
|
||||
eps * eps)
|
||||
grad_cov.mul_(1-diag_smooth).add_(extra_diag.diag())
|
||||
grad_cov = (grad_cov * (1-diag_smooth)).add_(extra_diag.diag())
|
||||
return grad_cov
|
||||
|
||||
def _get_cov(self, x: Tensor, dim: int) -> Tensor:
|
||||
@ -1338,6 +1344,14 @@ def _test_eve_cain():
|
||||
start = timeit.default_timer()
|
||||
for epoch in range(150):
|
||||
scheduler.step_epoch()
|
||||
|
||||
if epoch == 130:
|
||||
opts = diagnostics.TensorDiagnosticOptions(
|
||||
2 ** 22
|
||||
) # allow 4 megabytes per sub-module
|
||||
diagnostic = diagnostics.attach_diagnostics(m, opts)
|
||||
|
||||
|
||||
for n, (x,y) in enumerate(train_pairs):
|
||||
y_out = m(x)
|
||||
loss = ((y_out - y)**2).mean() * 100.0
|
||||
@ -1356,6 +1370,8 @@ def _test_eve_cain():
|
||||
optim.zero_grad()
|
||||
scheduler.step_batch()
|
||||
|
||||
diagnostic.print_diagnostics()
|
||||
|
||||
stop = timeit.default_timer()
|
||||
print(f"Iter={iter}, Time taken: {stop - start}")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user