From fb36712e6b2541e35c25bb1dfa2284d8d4925fe9 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 8 Jul 2022 05:22:24 +0800
Subject: [PATCH] Another bug fix, regarding Q being transposed.

---
 .../ASR/pruned_transducer_stateless7/optim.py | 67 ++++++++++---------
 1 file changed, 36 insertions(+), 31 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 8d16d7b1d..0293a3d3a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -60,7 +60,7 @@ class LearnedGradient(Optimizer):
         params,
         lr=3e-02,
         size_lr_scale=0.1,
-        meta_lr_scale=0.1,
+        meta_lr_scale=0.2,
         betas=(0.9, 0.98),
         eps=1.0e-08,
         size_update_period=1,
@@ -183,7 +183,7 @@ class LearnedGradient(Optimizer):
                 # re-estimate when we "rediagonalize" (this diagonalizes
                 # the gradient covariance).
                 # If our parameter matrix M is of shape (..., size),
-                # we'll view M as being M == torch.matmul(N, q)
+                # we'll view M as being M == torch.matmul(N, Q)
                 # so p can be interpreted as having shape
                 # (size, size), interpreted as being indexed [diagonalized_coordinate, canonical_coordinate]
                 # by "diagonalize" we mean that we diagonalize the gradient covariance.
@@ -255,6 +255,7 @@ class LearnedGradient(Optimizer):
                 if step % lr_est_period == 0:
                     self._update_lrs(group, p, state)
                 if step % (lr_est_period * diagonalize_period) == 0:
+                    logging.info("Diagonalizing")
                     self._diagonalize_lrs(group, p, state)
                     self._zero_exp_avg_sq(state)
                 if True:
@@ -555,8 +556,8 @@ class LearnedGradient(Optimizer):
             S_new *= 1.0 / S_new.mean()  # normalize so mean is 1.0
             S_new.clamp_(lr_mat_min, lr_mat_max)  # apply limits once more.
             # Reconstruct Q with the modified S.
-            Q[:] = torch.matmul(U * S, V.t())
-            if random.random() < 0.1:
+            Q[:] = torch.matmul(U * S_new, V.t())
+            if random.random() < 0.03:
                 subsample = max(1, S.numel() // 20)
                 logging.info(f"shape={tuple(p.shape)}, dim={dim}, modifed S from {S[::subsample]} to {S_new[::subsample]}")
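The hunk above rebuilds Q from its SVD after clamping and renormalizing the singular values, and the fix makes the reconstruction use the modified spectrum S_new rather than the stale S. A minimal standalone check of the broadcasting identity the fixed line relies on (toy shapes and clamp limits, not the optimizer's code):

    import torch

    torch.manual_seed(0)
    Q = torch.randn(6, 6)
    U, S, V = Q.svd()                  # Q == U @ diag(S) @ V.t()

    S_new = S.clamp(0.5, 2.0)          # stand-in for the clamp / normalize steps
    S_new = S_new / S_new.mean()

    # Elementwise U * S_new scales the columns of U, i.e. equals U @ diag(S_new).
    assert torch.allclose(U * S_new, U @ torch.diag(S_new), atol=1e-6)

    # So the fixed line rebuilds a matrix whose singular values are S_new, not S.
    Q_rebuilt = torch.matmul(U * S_new, V.t())
    assert torch.allclose(Q_rebuilt.svd().S, S_new, atol=1e-5)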
@@ -580,7 +581,7 @@ class LearnedGradient(Optimizer):
         # index).
         grad_cov = state[f"grad_cov_{dim}"]

-        N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q))
+        N_grad_cov = torch.matmul(Q, torch.matmul(grad_cov, Q.t()))
         N_grad_cov = N_grad_cov + N_grad_cov.t()  # ensure symmetric
         U, S, V = N_grad_cov.svd()

@@ -590,7 +591,7 @@ class LearnedGradient(Optimizer):
         # Now, we can diagonalize N_grad_cov with:
         #    U^T N_grad_cov U == S.
         # N_grad_cov is a sum of N_grad^T N_grad.
-        # So U^T N_grad^T N_grad U is diagonal.
+        # We know U^T N_grad_cov U is diagonal, so U^T N_grad^T N_grad U is diagonal.
         # The linearized pseudo-loss can be written as tr(N_grad^T N_grad).
         # This can be written as tr(U U^T N_grad^T N_grad), since U U^T == I,
         #
@@ -601,7 +602,8 @@ class LearnedGradient(Optimizer):
         # (hat_N means \hat{N}, or N with a hat on it).
         # So if we interpret hat_N = N U, the gradient covariance w.r.t.
         # hat_N will be diagonalized.  We also modify Q to hat_Q when
-        # we modify hat_N, to keep the product M = N Q = N U U^T Q = hat_N hat_Q
+        # we modify hat_N, to keep the product M unchanged:
+        #    M = N Q = N U U^T Q = hat_N hat_Q
         # This can be done by setting
         #   hat_Q = U^T Q   (eq.10)
         #
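These hunks carry the fix named in the subject line: with M = N Q, the gradient with respect to N is (dL/dM) Q^T, so moving the accumulated gradient covariance into the N basis conjugates by Q on the left and Q^T on the right. Below is a small self-contained sketch, with toy shapes and an arbitrary loss rather than the optimizer's own quantities, that checks this identity and the hat_N / hat_Q bookkeeping of eq.10:

    import torch

    torch.manual_seed(0)
    size = 5
    N = torch.randn(3, size, dtype=torch.float64, requires_grad=True)
    Q = torch.randn(size, size, dtype=torch.float64)

    M = N @ Q
    loss = (M ** 3).sum()                       # arbitrary differentiable scalar
    M_grad, N_grad = torch.autograd.grad(loss, (M, N))

    # Chain rule: since M = N @ Q, the gradient w.r.t. N is (dL/dM) @ Q^T.
    assert torch.allclose(N_grad, M_grad @ Q.t())

    grad_cov = M_grad.t() @ M_grad              # covariance in the canonical (M) basis
    N_grad_cov = N_grad.t() @ N_grad            # the same statistic in the N basis
    assert torch.allclose(N_grad_cov, Q @ grad_cov @ Q.t())

    # Rediagonalization: with U orthogonal (from the SVD of N_grad_cov),
    # hat_N = N @ U and hat_Q = U^T @ Q leave the product unchanged,
    # i.e. M = N Q = (N U)(U^T Q) = hat_N hat_Q, as in eq.10.
    U, S, _ = N_grad_cov.svd()
    N_d = N.detach()
    assert torch.allclose((N_d @ U) @ (U.t() @ Q), N_d @ Q)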
@@ -814,8 +816,10 @@ class LearnedGradient(Optimizer):
             if x.shape[dim] == 1:
                 continue
             Q = state[f"Q_{dim}"]
-            if not forward:
-                # Q is indexed [canonical_index, diagonalized_index]
+            if forward:
+                # Q is indexed [diagonalized_index, canonical_index]; in the forward
+                # direction we want to change canonical to diagonalized index, so have
+                # to transpose.
                 Q = Q.t()
             # TODO: could possibly somehow force the output format to be unchanged.
             x = x.transpose(-1, dim)
@@ -1156,7 +1160,7 @@ class LearnedGradient(Optimizer):
         try:
             P = 0.5 * (P + P.t())
             _,s,_ = P.svd()
-            print(f"Min,max eig of P: {s.min()},{s.max()}")
+            logging.info(f"Min,max eig of P: {s.min()},{s.max()}")
         except:
             pass
         # testing... note, this is only true modulo "eps"
@@ -1167,7 +1171,7 @@ class LearnedGradient(Optimizer):
         # Roundoff can cause significant differences, so use a fairly large
         # threshold of 0.001.  We may increase this later or even remove the check.
         if not C_diff.abs().mean() < 0.01 * C_smoothed.diag().mean():
-            print(f"Warning: large C_diff: {C_diff.abs().mean()}, C diag mean: {C_smoothed.diag().mean()}")
+            logging.info(f"Warning: large C_diff: {C_diff.abs().mean()}, C diag mean: {C_smoothed.diag().mean()}")

         return P

@@ -1545,7 +1549,7 @@ class Cain(Optimizer):
                     var_factor = var_factor.mean(dim=dims, keepdim=True)

                     #if random.random() < 0.01:
-                    #    print(f"shape={p.shape}, var_factor={var_factor}")
+                    #    logging.info(f"shape={p.shape}, var_factor={var_factor}")

                     param_rms.mul_(var_factor.sqrt())

@@ -1648,7 +1652,7 @@ class LRScheduler(object):
     def print_lr(self, is_verbose, group, lr):
         """Display the current learning rate."""
         if is_verbose:
-            print(
+            logging.info(
                 f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
                 f" of group {group} to {lr:.4e}."
             )
@@ -1797,7 +1801,7 @@ class Eve(Optimizer):

                 if random.random() < 0.0005:
                     step = (exp_avg/denom) * step_size
-                    print(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}")
+                    logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}")

         return loss

@@ -1863,8 +1867,8 @@ def _test_eden():
         scheduler.step_batch()
         optim.zero_grad()

-    print("last lr = ", scheduler.get_last_lr())
-    print("state dict = ", scheduler.state_dict())
+    logging.info(f"last lr = {scheduler.get_last_lr()}")
+    logging.info(f"state dict = {scheduler.state_dict()}")


 def _test_eve_cain():
@@ -1873,7 +1877,7 @@ def _test_eve_cain():
     E = 100
     B = 4
     T = 2
-    print("in test_eve_cain")
+    logging.info("in test_eve_cain")
     device = torch.device('cuda')
     dtype = torch.float32

@@ -1921,7 +1925,8 @@ def _test_eve_cain():
                 avg_loss = loss.item()
             else:
                 avg_loss = 0.95 * avg_loss + 0.05 * loss.item()
-            if n == 0 and epoch % 5 == 0:
+            #if n == 0 and epoch % 5 == 0:
+            if True:
                 norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
                 norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
                 norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
@@ -1931,7 +1936,7 @@ def _test_eve_cain():
                 #scale2 = '%.2e' % (m[2].weight_scale.exp().item())
                 #scale2b = '%.2e' % (m[2].bias_scale.exp().item())
                 lr = scheduler.get_last_lr()[0]
-                print(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
+                logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
             loss.log().backward()
             optim.step()
             optim.zero_grad()
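The first hunk above flips which direction of the coordinate change transposes Q. A tiny sketch of the indexing convention the new comment describes, using hypothetical shapes rather than the optimizer's helper: with Q indexed [diagonalized_index, canonical_index], mapping the last axis of a tensor from canonical to diagonalized coordinates contracts over the canonical index, which is a right-multiplication by Q.t():

    import torch

    torch.manual_seed(0)
    size = 4
    Q = torch.randn(size, size)            # indexed [diagonalized, canonical]
    x = torch.randn(2, 3, size)            # last axis is in canonical coordinates

    x_diag = torch.matmul(x, Q.t())        # last axis is now diagonalized
    x_diag_ref = torch.einsum("dc,...c->...d", Q, x)
    assert torch.allclose(x_diag, x_diag_ref, atol=1e-6)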
@@ -1940,24 +1945,24 @@ def _test_eve_cain():
         #diagnostic.print_diagnostics()

         stop = timeit.default_timer()
-        print(f"Iter={iter}, Time taken: {stop - start}")
+        logging.info(f"Iter={iter}, Time taken: {stop - start}")

-    print("last lr = ", scheduler.get_last_lr())
-    #print("state dict = ", scheduler.state_dict())
-    #print("optim state_dict = ", optim.state_dict())
-    print("input_magnitudes = ", input_magnitudes)
-    print("output_magnitudes = ", output_magnitudes)
+    logging.info(f"last lr = {scheduler.get_last_lr()}")
+    #logging.info("state dict = ", scheduler.state_dict())
+    #logging.info("optim state_dict = ", optim.state_dict())
+    logging.info(f"input_magnitudes = {input_magnitudes}")
+    logging.info(f"output_magnitudes = {output_magnitudes}")

     def stddev(x):
         return ((x-x.mean())**2).mean().sqrt()
-    print("Un-normalized input col magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=0).sqrt().log()))
-    print("Normalized input col magnitudes log-stddev: ", stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log()))
+    logging.info(f"Un-normalized input col magnitudes log-stddev: {stddev((m[0].weight**2).sum(dim=0).sqrt().log())}")
+    logging.info(f"Normalized input col magnitudes log-stddev: {stddev(((m[0].weight**2).sum(dim=0).sqrt() * input_magnitudes).log())}")

-    print("Un-normalized 0-output row magnitudes log-stddev: ", stddev((m[0].weight**2).sum(dim=1).sqrt().log()))
-    print("Un-normalized 2-input col magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=0).sqrt().log()))
-    print("Un-normalized 2-output row magnitudes log-stddev: ", stddev((m[2].weight**2).sum(dim=1).sqrt().log()))
+    logging.info(f"Un-normalized 0-output row magnitudes log-stddev: {stddev((m[0].weight**2).sum(dim=1).sqrt().log())}")
+    logging.info(f"Un-normalized 2-input col magnitudes log-stddev: {stddev((m[2].weight**2).sum(dim=0).sqrt().log())}")
+    logging.info(f"Un-normalized 2-output row magnitudes log-stddev: {stddev((m[2].weight**2).sum(dim=1).sqrt().log())}")

-    print("Normalized output row magnitudes log-stddev: ", stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log()))
+    logging.info(f"Normalized output row magnitudes log-stddev: {stddev(((m[2].weight**2).sum(dim=1).sqrt() / output_magnitudes).log())}")
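The remaining hunks convert print() diagnostics to logging.info(), which only produce output once the root logger is configured. A minimal sketch using only the standard library; the file's actual __main__ block, if it has one, may already do something equivalent:

    import logging

    if __name__ == "__main__":
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        )
        _test_eve_cain()  # test function defined above; assumes a CUDA device is available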