From 67c402a369cee82abcb4d00be902e1d45d3d3b93 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 15 May 2022 13:28:00 +0800
Subject: [PATCH] Add some debugging/diagnostic code

---
 .../pruned_transducer_stateless4b/optim.py | 53 ++++++++++++-------
 1 file changed, 34 insertions(+), 19 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4b/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless4b/optim.py
index e2d949e74..2a4845694 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless4b/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4b/optim.py
@@ -122,6 +122,10 @@ def _update_factorization(x: Tensor, x_factorized: Tensor,
         this_mean = _mean_like(x_norm_var, shape)
         f = ((1.0 - speed) + speed * this_mean)
         factors.append(f)
+    # temp
+    #import random
+    #if random.random() < 0.1:
+    #    print("factor norms: ", list((x-1.0).abs().mean().item() for x in factors))
     x_factorized *= _product(*factors)
     # TEMP
     #import random
@@ -149,8 +153,6 @@ def _get_factor_grads(x: Tensor, x_grad: Tensor) -> List[Tensor]:
 
 
 
-
-
 def _init_factors(x: Tensor, eps: float = 1.0e-04) -> Tuple[Tensor]:
     """
     Initialize some factors which we will use to normalize the variance of x.
@@ -349,16 +351,15 @@ class Abel(Optimizer):
 
                 if step < 10 or step % 10 == 1:
                     # do this only every 10 steps, to save time.
-                    num_factors = len(factors_exp_avg_sq)
                     _update_factorization(p, factorization,
                                           speed=0.1, eps=eps)
-
                 factors_sum = None
                 for g, e in zip(factor_grads, factors_exp_avg_sq):
                     update_exp_avg_sq(g, e)
-                    this_denom = (e + eps*eps).sqrt()
+                    this_denom = (e/bias_correction2 + eps*eps).sqrt()
+                    assert g.shape == this_denom.shape
                     factor_delta = g / this_denom
                     factors_sum = (factor_delta if factors_sum is None
                                    else factors_sum + factor_delta)
@@ -395,16 +396,12 @@ class Abel(Optimizer):
                 # `p * factors_sum` is the contribution from changes in x_factor1
                 # and x_factor2: again, before taking into account the learning
                 # rate or momentum.
-
-
                 this_delta = ((grad * factorization / denom) +
                               p * factors_sum)
 
                 # compute the moving-average change in parameters, and add it to p.
                 delta.mul_(beta1).add_(this_delta, alpha=-lr*(1-beta1))
-                if step % 50 == 0 and False:
-                    print("This_delta norm = ", delta.norm())
                 p.add_(delta)
 
@@ -745,21 +742,27 @@ def _test_abel():
     B = 4
     T = 2
     print("in test_abel")
+    device = torch.device('cuda')
+    dtype = torch.float32
+
+    # these input_magnitudes and output_magnitudes are to test that
+    # Abel is working as we expect and is able to adjust scales of
+    # different dims differently.
+    input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
+    output_magnitudes = (0.0 * torch.randn(E, dtype=dtype, device=device)).exp()
+
     for iter in [0,1]:
-        device = torch.device('cuda')
-        dtype = torch.float32
         Linear = torch.nn.Linear if iter == 0 else ScaledLinear
         m = torch.nn.Sequential(Linear(E, 200),
                                 torch.nn.ReLU(),
                                 Linear(200, E)).to(device)
-
-        train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype),
-                         torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ]
+        train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
+                         torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ]
 
         if iter == 0: optim = Abel(m.parameters(), lr=0.003)
         else: optim = Eve(m.parameters(), lr=0.003)
-        scheduler = Eden(optim, lr_batches=300, lr_epochs=2, verbose=False)
+        scheduler = Eden(optim, lr_batches=300, lr_epochs=20, verbose=False)
 
         start = timeit.default_timer()
         for epoch in range(150):
@@ -767,7 +770,7 @@ def _test_abel():
             for n, (x,y) in enumerate(train_pairs):
                 y_out = m(x)
                 loss = ((y_out - y)**2).mean() * 100.0
-                if n % 10 == 0 and epoch % 10 == 0:
+                if n == 0 and epoch % 10 == 0:
                     norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
                     norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
                     norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
@@ -777,7 +780,7 @@ def _test_abel():
                     #scale2 = '%.2e' % (m[2].weight_scale.exp().item())
                     #scale2b = '%.2e' % (m[2].bias_scale.exp().item())
                     print(f"Iter {iter}, epoch {epoch}, batch {n}, loss {loss.item()}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
-                loss.backward()
+                loss.log().backward()
                 optim.step()
                 optim.zero_grad()
                 scheduler.step_batch()
@@ -786,10 +789,22 @@ def _test_abel():
 
         print(f"Iter={iter}, Time taken: {stop - start}")
         print("last lr = ", scheduler.get_last_lr())
-        print("state dict = ", scheduler.state_dict())
+        #print("state dict = ", scheduler.state_dict())
+        #print("optim state_dict = ", optim.state_dict())
+        print("input_magnitudes = ", input_magnitudes)
+        print("output_magnitudes = ", output_magnitudes)
+
+        def stddev(x):
+            return ((x-x.mean())**2).mean().sqrt()
+        print("Un-normalized input col magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=0).sqrt().log()))
+        print("Normalized input col magnitudes log-stddev: ", stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log()))
+
+        print("Un-normalized 0-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log()))
+        print("Un-normalized 2-output col magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=0).sqrt().log()))
+        print("Un-normalized 2-output row magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=1).sqrt().log()))
 
 
 if __name__ == "__main__":
     _test_abel()
-    _test_eden()
+    #_test_eden()
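
-- 

Note on the bias_correction2 change in the Abel hunk: dividing the
second-moment accumulator by (1 - beta2**step) before taking the square
root is the standard Adam-style debiasing. Early in training the
accumulator underestimates grad**2, so the un-debiased denominator is too
small and the factor updates come out too large. A minimal standalone
sketch of the effect; the names and the beta2 value here are illustrative,
not taken from optim.py:

    import torch

    beta2 = 0.98            # assumed Adam-like second-moment decay
    eps = 1.0e-04
    exp_avg_sq = torch.zeros(4)
    grad = torch.ones(4)    # pretend the factor gradient is constantly 1.0

    for step in range(1, 6):
        # the usual Adam-family accumulator update
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        bias_correction2 = 1 - beta2 ** step
        naive = (exp_avg_sq + eps * eps).sqrt()
        debiased = (exp_avg_sq / bias_correction2 + eps * eps).sqrt()
        # naive starts near 0.14, so grad/naive is ~7x too large;
        # debiased stays near |grad| = 1 from the first step.
        print(step, naive[0].item(), debiased[0].item())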
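
Note on loss.log().backward(): by the chain rule, d(log L)/dtheta =
(1/L) * dL/dtheta, so back-propagating the log of the loss rescales the
gradient by 1/loss and keeps its magnitude roughly independent of the raw
loss value (which the test inflates by 100.0). A tiny self-contained
check:

    import torch

    theta = torch.tensor([3.0], requires_grad=True)
    loss = (theta ** 2).sum() * 100.0   # raw loss = 900
    loss.log().backward()
    # d(log L)/dtheta = (1/900) * (200 * theta) = 600/900
    print(theta.grad)                   # tensor([0.6667])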
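
Note on the stddev diagnostics at the end of _test_abel: each print takes
the standard deviation of the log of a weight matrix's column (or row)
norms, i.e. it measures how uniform the per-dimension scales are. If Abel
has adapted the first layer to the uneven input_magnitudes, the column
norms should vary like 1/input_magnitude, so multiplying them back by
input_magnitudes should leave a near-constant vector and hence a much
smaller log-stddev. A sketch of that calculation on synthetic data, not
code from the repo:

    import torch

    def stddev(x):
        # population standard deviation, matching the helper in the patch
        return ((x - x.mean()) ** 2).mean().sqrt()

    input_magnitudes = torch.linspace(0.5, 2.0, 100)
    # simulate a well-adapted first layer: each weight column scales like
    # 1/input_magnitude, so all inputs land at a comparable scale
    weight = torch.randn(200, 100) / input_magnitudes

    col_mags = (weight ** 2).sum(dim=0).sqrt()   # one norm per input dim
    print("un-normalized log-stddev:", stddev(col_mags.log()).item())
    print("normalized log-stddev:  ",
          stddev((col_mags * input_magnitudes).log()).item())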