mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
More drafts of new method, not tested.
This commit is contained in:
parent
26815d177f
commit
e6d00ee3e4
@ -715,7 +715,7 @@ class LearnedGradient(Optimizer):
|
||||
# needed when we rediagonalize which is not that common.
|
||||
grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))
|
||||
|
||||
grad = self._project(grad, forward=True, state)
|
||||
grad = self._project(grad, state, forward=True)
|
||||
|
||||
exp_avg_sq = state["exp_avg_sq"]
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
|
||||
@ -731,7 +731,7 @@ class LearnedGradient(Optimizer):
|
||||
grad = grad / denom
|
||||
|
||||
# and project back..
|
||||
grad = self._project(grad, forward=False, state)
|
||||
grad = self._project(grad, state, forward=False)
|
||||
|
||||
|
||||
scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
|
||||
@ -782,8 +782,8 @@ class LearnedGradient(Optimizer):
|
||||
|
||||
def _project(self,
|
||||
x: Tensor,
|
||||
forward: bool,
|
||||
state: dict) -> Tensor:
|
||||
state: dict,
|
||||
forward: bool) -> Tensor:
|
||||
"""
|
||||
Multiply a tensor x by proj_{dim} on each of its dimensions.
|
||||
If forward == True, converts x from canonical to preconditioned/diagonalized
|
||||
|
Loading…
x
Reference in New Issue
Block a user