From e6d00ee3e4d9d431fd180793dcc80d079963d5ba Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 6 Jul 2022 23:05:06 -0700
Subject: [PATCH] More drafts of new method, not tested.

---
 egs/librispeech/ASR/pruned_transducer_stateless7/optim.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 71bc0b463..f27ff153e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -715,7 +715,7 @@ class LearnedGradient(Optimizer):
             # needed when we rediagonalize which is not that common.
             grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))
 
-            grad = self._project(grad, forward=True, state)
+            grad = self._project(grad, state, forward=True)
 
             exp_avg_sq = state["exp_avg_sq"]
             exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
@@ -731,7 +731,7 @@ class LearnedGradient(Optimizer):
             grad = grad / denom
 
             # and project back..
-            grad = self._project(grad, forward=False, state)
+            grad = self._project(grad, state, forward=False)
 
             scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
 
@@ -782,8 +782,8 @@ class LearnedGradient(Optimizer):
 
     def _project(self,
                  x: Tensor,
-                 forward: bool,
-                 state: dict) -> Tensor:
+                 state: dict,
+                 forward: bool) -> Tensor:
         """
         Multiply a tensor x by proj_{dim} on each of its dimensions.
         If forward == True, converts x from canonical to preconditioned/diagonalized
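
Note on the change above: the patch reorders the `state` and `forward` parameters of `_project` and updates both call sites. This is needed because the old call form, `self._project(grad, forward=True, state)`, places a positional argument after a keyword argument, which is a SyntaxError in Python. Below is a minimal, self-contained sketch illustrating the new call order; the class name `ProjectionSketch`, the `demo` method, and the stand-in projection logic are hypothetical and are not taken from the real LearnedGradient optimizer.

import torch
from torch import Tensor


class ProjectionSketch:
    def _project(self, x: Tensor, state: dict, forward: bool) -> Tensor:
        # Stand-in body: the real method multiplies x by a learned projection
        # on each of its dimensions, converting between canonical and
        # preconditioned/diagonalized co-ordinates depending on `forward`.
        proj = state.get("proj", torch.eye(x.shape[-1]))
        return x @ (proj if forward else proj.t())

    def demo(self, grad: Tensor, state: dict) -> Tensor:
        # Mirrors the updated call sites in the patch: `state` positional,
        # `forward` passed as a keyword argument.
        grad = self._project(grad, state, forward=True)   # to diagonalized space
        # ... element-wise normalization by the running second moment would
        # happen here in the real optimizer ...
        grad = self._project(grad, state, forward=False)  # back to canonical space
        return grad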