diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 71bc0b463..f27ff153e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -715,7 +715,7 @@ class LearnedGradient(Optimizer):
                 # needed when we rediagonalize which is not that common.
                 grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))

-        grad = self._project(grad, forward=True, state)
+        grad = self._project(grad, state, forward=True)

         exp_avg_sq = state["exp_avg_sq"]
         exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
@@ -731,7 +731,7 @@ class LearnedGradient(Optimizer):
         grad = grad / denom

         # and project back..
-        grad = self._project(grad, forward=False, state)
+        grad = self._project(grad, state, forward=False)

         scalar_exp_avg_sq = state["scalar_exp_avg_sq"]

@@ -782,8 +782,8 @@ class LearnedGradient(Optimizer):
     def _project(self,
                  x: Tensor,
-                 forward: bool,
-                 state: dict) -> Tensor:
+                 state: dict,
+                 forward: bool) -> Tensor:
         """
         Multiply a tensor x by proj_{dim} on each of its dimensions.
         If forward == True, converts x from canonical to preconditioned/diagonalized
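
Note (not part of the patch): the reordering is needed because Python forbids a positional argument after a keyword argument, so the old call `self._project(grad, forward=True, state)` would not even parse. A minimal standalone sketch of the same fix, using a hypothetical `Proj` class standing in for `LearnedGradient`:

import torch
from torch import Tensor


class Proj:
    def _project(self, x: Tensor, state: dict, forward: bool) -> Tensor:
        # Stand-in body; the real method multiplies x by a per-dimension
        # projection stored in state.
        return x


p = Proj()
grad = torch.zeros(2, 3)
state: dict = {}

# Old argument order would not parse:
#   p._project(grad, forward=True, state)   # SyntaxError: positional
#                                           # argument follows keyword argument
grad = p._project(grad, state, forward=True)   # fixed ordering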