diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 8b5aeac2d..e624d53e7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -60,7 +60,7 @@ class LearnedGradient(Optimizer):
         params,
         lr=3e-02,
         size_lr_scale=0.1,
-        meta_lr_scale=0.01,
+        meta_lr_scale=0.0,
         betas=(0.9, 0.98),
         eps=1.0e-08,
         size_update_period=1,
@@ -431,8 +431,14 @@ class LearnedGradient(Optimizer):
             # M_delta = M proj_delta   (eq.5)
             proj_grad = state[f"proj_grad_{dim}"]
             Q = state[f"Q_{dim}"]
+
+            # delta_scale < 1 will make it update the learning rates faster than it otherwise would,
+            # as we'll reach equilbrium with M less rapidly.
+            delta_scale=0.5
+
             if True:
                 # Simpler update, do it to the right of Q.
+                # symmetrize, otherwise there is a bad interaction with the regular update.
                 proj_grad[:] = proj_grad + proj_grad.t()
                 proj_grad_var = self._get_matrix_var(proj_grad, group["eps"])
                 denom = proj_grad_var.sqrt()
@@ -441,9 +447,6 @@
                 M_delta = torch.matmul(M, proj_delta)
                 M_delta_full = M_delta.reshape(M_full.shape)
                 this_delta = M_delta_full.transpose(dim, -1)
-                # delta_scale < 1 will make it update the learning rates faster than it otherwise would,
-                # as we'll reach equilbrium with M less rapidly.
-                delta_scale=1.0
                 delta.add_(this_delta, alpha=-delta_scale*meta_lr*(1-beta1))
                 # there is no momentum on Q.
                 Q.add_(Q_delta, alpha=-meta_lr)
@@ -463,6 +466,10 @@
                 # (eq.1), proj2_grad = Q^{-T} proj_grad Q^T
                 proj2_grad = torch.matmul(Q_invt_proj_grad, Q.t())
 
+                # symmetrize, otherwise there is a bad interaction with the regular
+                # update in self._step().
+                proj2_grad = proj2_grad + proj2_grad.t()
+
                 # for now just estimate proj2_grad_var from proj2_grad; later, we
                 # may try aggregating it across iterations.  This is like exp_avg_sq
                 # in Adam.
@@ -494,9 +501,6 @@
                 # are the final changes, the only 2 we make in this loop that have
                 # side effects.
 
-                # delta_scale < 1 will make it update the learning rates faster than it otherwise would,
-                # as we'll reach equilbrium with M less rapidly.
-                delta_scale=1.0
                 delta.add_(this_delta, alpha=-delta_scale*meta_lr*(1-beta1))
                 # there is no momentum on Q.
                 Q.add_(Q_delta, alpha=-meta_lr)
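
For context, below is a minimal, self-contained PyTorch sketch of the kind of meta-update this patch adjusts: the projected gradient is symmetrized before it is used (the patch now also symmetrizes proj2_grad for the same reason), the resulting delta is applied with delta_scale < 1 (changed here from 1.0 to 0.5), and Q is updated without momentum. The toy shapes, the rms-style denom (a stand-in for self._get_matrix_var()), Q_delta, and the meta_lr value are assumptions for illustration, not code from optim.py; note that the new default meta_lr_scale=0.0 disables this meta-update unless the caller overrides it.

import torch

# Illustrative sketch only -- not the LearnedGradient class itself.  It mirrors
# the names in the diff (proj_grad, Q, delta, meta_lr, beta1, delta_scale), but
# the shapes, the rms-style denominator and Q_delta are assumptions.
torch.manual_seed(0)
size = 4
eps = 1.0e-08
meta_lr = 3e-03          # assumed stand-in for lr * meta_lr_scale
beta1 = 0.9
delta_scale = 0.5        # < 1: reach equilibrium with M less rapidly

Q = torch.eye(size) + 0.01 * torch.randn(size, size)
proj_grad = torch.randn(size, size)
delta = torch.zeros(size, size)

# symmetrize, otherwise there is a bad interaction with the regular update.
proj_grad = proj_grad + proj_grad.t()

# Adam-like normalization (stand-in for self._get_matrix_var()).
denom = (proj_grad ** 2).mean().sqrt() + eps
proj_delta = proj_grad / denom

# "Simpler update, do it to the right of Q"; taking Q_delta to be proj_delta
# itself is an assumption made for this sketch.
Q_delta = proj_delta
this_delta = torch.matmul(Q, proj_delta)

delta.add_(this_delta, alpha=-delta_scale * meta_lr * (1 - beta1))
# there is no momentum on Q.
Q.add_(Q_delta, alpha=-meta_lr)

print("delta norm:", delta.norm().item())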