mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Fix bug for scalar update
This commit is contained in:
parent
aa2237a793
commit
65bc964854
@ -497,7 +497,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
# and project back..
|
# and project back..
|
||||||
grad = self._project(grad, state, forward=False)
|
grad = self._project(grad, state, forward=False)
|
||||||
|
|
||||||
|
|
||||||
scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
|
scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
|
||||||
scalar_exp_avg_sq.mul_(beta2).add_((grad**2).mean(),
|
scalar_exp_avg_sq.mul_(beta2).add_((grad**2).mean(),
|
||||||
alpha=1-beta2)
|
alpha=1-beta2)
|
||||||
@ -525,11 +524,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
grad: Tensor,
|
grad: Tensor,
|
||||||
state: dict):
|
state: dict):
|
||||||
"""
|
"""
|
||||||
A form of the core update for tensors with a small number of elements,
|
A form of the core update for scalar tensors, where we cannot get a good
|
||||||
e.g. scalars. This is Adam where, if the numel() > 1, the learning rate
|
estimate of the parameter rms.
|
||||||
is proportional to the parameter rms value.
|
|
||||||
"""
|
"""
|
||||||
exp_avg_sq = state["scalar_exp_avg_sq"]
|
exp_avg_sq = state["exp_avg_sq"]
|
||||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
|
exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
|
||||||
value=1-beta2)
|
value=1-beta2)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user