Mirror of https://github.com/k2-fsa/icefall.git

More drafts of new method, not tested.

parent 26815d177f
commit e6d00ee3e4
@@ -715,7 +715,7 @@ class LearnedGradient(Optimizer):
             # needed when we rediagonalize which is not that common.
             grad_cov.mul_(beta2).add_(torch.matmul(this_g.t(), this_g))
 
-        grad = self._project(grad, forward=True, state)
+        grad = self._project(grad, state, forward=True)
 
         exp_avg_sq = state["exp_avg_sq"]
         exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
@@ -731,7 +731,7 @@ class LearnedGradient(Optimizer):
         grad = grad / denom
 
         # and project back..
-        grad = self._project(grad, forward=False, state)
+        grad = self._project(grad, state, forward=False)
 
 
        scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
@@ -782,8 +782,8 @@ class LearnedGradient(Optimizer):
 
     def _project(self,
                  x: Tensor,
-                 forward: bool,
-                 state: dict) -> Tensor:
+                 state: dict,
+                 forward: bool) -> Tensor:
         """
         Multiply a tensor x by proj_{dim} on each of its dimensions.
         If forward == True, converts x from canonical to preconditioned/diagonalized
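
The reordering matters because Python forbids a positional argument after a keyword argument, so the pre-fix call `self._project(grad, forward=True, state)` would not even parse; moving `state` ahead of `forward` in both the calls and the signature resolves this. Below is a minimal sketch (not icefall code; the class name, body, and tensor shapes are placeholders) showing only the corrected parameter order and call style:

import torch
from torch import Tensor


class ProjectOrderDemo:
    """Placeholder standing in for LearnedGradient; only the argument order is the point."""

    def _project(self, x: Tensor, state: dict, forward: bool) -> Tensor:
        # The real method multiplies x by a learned projection on each of its
        # dimensions (forward or inverse); here we simply return x unchanged.
        return x


demo = ProjectOrderDemo()
grad = torch.randn(3, 4)
state: dict = {}

# Valid: positional arguments first, the keyword argument last.
grad = demo._project(grad, state, forward=True)

# The pre-fix ordering, demo._project(grad, forward=True, state), raises
# SyntaxError: "positional argument follows keyword argument".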