diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index d8613ccce..197fa5ba1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -60,7 +60,7 @@ class LearnedGradient(Optimizer):
         params,
         lr=3e-02,
         size_lr_scale=0.1,
-        meta_lr_scale=0.0,
+        meta_lr_scale=0.05,
         betas=(0.9, 0.98),
         eps=1.0e-08,
         size_update_period=1,
@@ -189,9 +189,9 @@ class LearnedGradient(Optimizer):
                 state[f"Q_{dim}"] = torch.eye(size, **kwargs)

                 # proj_grad_{dim} is the gradient w.r.t. `proj`,
-                # which is a matrix we introduce to the right of p, i.e.
-                # instead of M == torch.matmul(N, p) we view it as
-                # M == torch.matmul(torch.matmul(N, torch.matmul(p, proj)))
+                # which is a matrix we introduce to the right of Q, i.e.
+                # instead of M == torch.matmul(N, Q) we view it as
+                # M == torch.matmul(N, torch.matmul(Q, proj))
                 # or equivalently M == torch.matmul(M_underlying, proj)
                 # where M_underlying is numerically equal to M, and proj is numerically
                 # equal to I, but they are viewed as separate variables.
@@ -377,19 +377,19 @@ class LearnedGradient(Optimizer):
                 # which will be of shape (size, size).  For purposes of our update, if we view
                 # the parameter matrix M as (M_underlying proj) with
                 # proj being numerically equal to I but treated as a variable, i.e. as
-                # matmul(M_underlying, proj), then instead of the loss being tr(M_grad^T M), we can write
+                # M = matmul(M_underlying, proj), then instead of the loss being tr(M_grad^T M), we can write
                 # it as tr(M_grad^T M_underlying proj), which can also be written as tr(proj_grad^T proj)
-                # where proj_grad == M_underlying^T M_grad.
+                # where proj_grad == M_underlying^T M_grad == M^T M_grad.
                 #
-                # Now, proj_grad is convenient to compute but it's not in a very normalized
-                # space, we want to normalize it before doing gradient descent on it.
+                # Now, proj_grad is convenient to compute, being invariant to Q, but it's not in a very
+                # normalized space; we want to normalize it before doing gradient descent on it.
                 # We're going to view M as
                 #   M == N Q,
                 # where Q is our Q_{dim} "learning-rate" matrix, of shape (size, size) indexed
                 # [diagonalized_index, canonical_index] (see the code), and N == M Q^{-1}.
-                # It's more convenient to do gradient descent on a `proj` matrix that's between N and Q,
-                # i.e. define `proj2`, also numerically equal to I but treated as a variable,
-                # and have:
+                # It's likely better numerically to do gradient descent on a `proj2` matrix
+                # that's between N and Q,
+                # i.e.:
                 #   M == N proj2 Q
                 # We want to obtain the derivative w.r.t. proj2.
                 # So the linearized pseudo-loss is
@@ -412,20 +412,13 @@ class LearnedGradient(Optimizer):
                 #   Q += proj2_delta Q        (eq.2)
                 #
                 # We also need to update M itself; this is easiest done by computing proj_delta
-                # (which is the difference proj from I).  The relationship between proj and proj2
+                # (which is the difference of proj from I).  The relationship between proj and proj2
                 # is given by equating
                 #    M == N proj2 Q == N Q proj,
-                # so proj2 Q == Q proj
-                #   proj == Q^{-1} proj2 Q
+                # so  proj2 Q == Q proj
+                #     proj == Q^{-1} proj2 Q
                 # and subtracting out the constant "I" part,
                 #    proj_delta == Q^{-1} proj2_delta Q   (eq.3)
-                # ... so the change in Q is given by:
-                #   Q := Q (I + proj_delta),
-                #   Q += Q proj_delta
-                #   Q_delta = Q proj_delta
-                # and looking at how we compute proj_delta (eq.3), we notice the right-hand subexpression
-                # is the sam as p_delta, i.e.:
-                #   Q_delta = proj2_delta Q    (eq.4)
                 #
                 # and then because we view the parameter matrix as
                 #   "M proj",
@@ -459,13 +452,11 @@ class LearnedGradient(Optimizer):

                 # we'll take into account the factor of -meta_lr * (1-beta1) later on;
                 # actually proj2_delta is the negative of a parameter change.
-                proj2_delta = proj2_grad_var / denom
+                proj2_delta = proj2_grad / denom

-                # See (eq.4), Q_delta = proj2_delta q
+                # See (eq.2), Q += proj2_delta Q
                 Q_delta = torch.matmul(proj2_delta, Q)

-                assert not torch.all(proj2_delta == 0)
-
                 # See (eq.3), proj_delta = Q^{-1} proj2_delta Q = Q^{-1} Q_delta
                 try:
                     proj_delta = torch.linalg.solve(Q, Q_delta)
@@ -483,8 +474,8 @@ class LearnedGradient(Optimizer):
                 # are the final changes, the only 2 we make in this loop that have
                 # side effects.

-                # delta_scale will make it update the learning rates faster than it otherwise would.
-                delta_scale=0.2
+                # delta_scale < 1 will make it update the learning rates faster than it otherwise would.
+                delta_scale=1.0
                 delta.add_(this_delta, alpha=-delta_scale*meta_lr*(1-beta1))
                 # there is no momentum on Q.
                 Q.add_(Q_delta, alpha=-meta_lr)
@@ -586,7 +577,7 @@ class LearnedGradient(Optimizer):
                 Q[:] = torch.matmul(U * S_new, V.t())
                 if random.random() < 0.03:
                     subsample = max(1, S.numel() // 20)
-                    logging.info(f"shape={tuple(p.shape)}, dim={dim}, modifed S from {S[::subsample]} to {S_new[::subsample]}")
+                    logging.info(f"shape={tuple(p.shape)}, dim={dim}, modified S from {S[::subsample]} to {S_new[::subsample]}")

         if True:
             # This block does the actual diagonalization.
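
As a sanity check on the (eq.2)/(eq.3) bookkeeping in the comments above, namely that stepping proj2 by proj2_delta and stepping proj by proj_delta == Q^{-1} proj2_delta Q yield the same new M, here is a minimal standalone sketch. It is not part of the patch; the variable names mirror the comments but are otherwise illustrative.

    # Standalone check of (eq.2)/(eq.3); illustrative only, not part of the patch.
    import torch

    torch.manual_seed(0)
    size = 4
    M = torch.randn(size, size, dtype=torch.double)   # parameter matrix
    Q = torch.randn(size, size, dtype=torch.double)   # "learning-rate" matrix
    N = torch.matmul(M, torch.linalg.inv(Q))          # N == M Q^{-1}, so M == N Q
    proj2_delta = 0.01 * torch.randn(size, size, dtype=torch.double)

    # (eq.2): the change to Q implied by a step on proj2.
    Q_delta = torch.matmul(proj2_delta, Q)

    # (eq.3): proj_delta == Q^{-1} proj2_delta Q == Q^{-1} Q_delta, computed
    # with solve() rather than an explicit inverse, as in the patch.
    proj_delta = torch.linalg.solve(Q, Q_delta)

    # Both views of the update must give the same new M:
    #   N (I + proj2_delta) Q  ==  M (I + proj_delta)
    I = torch.eye(size, dtype=torch.double)
    assert torch.allclose(N @ (I + proj2_delta) @ Q, M @ (I + proj_delta))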
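The last hunk touches the block that rebuilds Q from a modified singular-value spectrum: torch.matmul(U * S_new, V.t()) relies on broadcasting to scale the columns of U by S_new. A small sketch of that pattern follows; the clamping rule for S_new here is made up, since the patch does not show how S_new is computed.

    # Illustrative sketch of the "rebuild Q with modified singular values" pattern.
    import torch

    torch.manual_seed(0)
    Q = torch.randn(6, 6, dtype=torch.double)

    U, S, V = torch.svd(Q)                 # Q == U diag(S) V^T
    S_new = S.clamp(min=0.1 * S.mean())    # hypothetical modification of the spectrum
    Q[:] = torch.matmul(U * S_new, V.t())  # in-place, as in the patch

    # U * S_new scales the columns of U, so this is U @ diag(S_new) @ V^T.
    assert torch.allclose(Q, U @ torch.diag(S_new) @ V.t())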