diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 7afb29690..8b5aeac2d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -60,7 +60,7 @@ class LearnedGradient(Optimizer): params, lr=3e-02, size_lr_scale=0.1, - meta_lr_scale=0.05, + meta_lr_scale=0.01, betas=(0.9, 0.98), eps=1.0e-08, size_update_period=1, @@ -379,9 +379,10 @@ class LearnedGradient(Optimizer): # proj being numerically equal to I but treated as a variable, i.e. as # M = matmul(M_underlying, proj), then instead of the loss being tr(M_grad^T M), we can write # it as tr(M_grad^T M_underlying proj), which can also be written as tr(proj_grad^T proj) - # where proj_grad == M_underlying^T M_grad == M^T M_grad + # where proj_grad == (M_grad^T M_underlying)^T == M_underlying^T M_grad == M^T M_grad # - # Now, proj_grad is convenient to compute, being invariant of Q, but it's not in a very normalized + # Now, proj_grad is convenient to compute, being independent of Q, + # but it's not in a very normalized # space; we want to normalize it before doing gradient descent on it. # We're going to view M as # M == N Q, @@ -430,6 +431,25 @@ class LearnedGradient(Optimizer): # M_delta = M proj_delta (eq.5) proj_grad = state[f"proj_grad_{dim}"] Q = state[f"Q_{dim}"] + if True: + # Simpler update, do it to the right of Q. + proj_grad[:] = proj_grad + proj_grad.t() + proj_grad_var = self._get_matrix_var(proj_grad, group["eps"]) + denom = proj_grad_var.sqrt() + proj_delta = proj_grad / denom + Q_delta = torch.matmul(Q, proj_delta) + M_delta = torch.matmul(M, proj_delta) + M_delta_full = M_delta.reshape(M_full.shape) + this_delta = M_delta_full.transpose(dim, -1) + # delta_scale < 1 will make it update the learning rates faster than it otherwise would, + # as we'll reach equilbrium with M less rapidly. + delta_scale=1.0 + delta.add_(this_delta, alpha=-delta_scale*meta_lr*(1-beta1)) + # there is no momentum on Q. + Q.add_(Q_delta, alpha=-meta_lr) + proj_grad.zero_() + continue + # Next we want to implement (eq.1), proj2_grad = Q^{-T} proj_grad Q^T # torch.solve(A, B) returns A^{-1} B.