Fix bug for scalar update

Daniel Povey 2022-07-09 10:14:20 +08:00
parent aa2237a793
commit 65bc964854


@@ -497,7 +497,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # and project back..
         grad = self._project(grad, state, forward=False)
         scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
         scalar_exp_avg_sq.mul_(beta2).add_((grad**2).mean(),
                                            alpha=1-beta2)
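
For context, the statistic in this hunk is an exponential moving average of the mean squared gradient: a single scalar tracking the overall gradient magnitude of the whole tensor. Below is a minimal standalone sketch of just that statistic; the function name and the usage values are illustrative, not part of the optimizer.

import torch

def update_scalar_second_moment(scalar_exp_avg_sq: torch.Tensor,
                                grad: torch.Tensor,
                                beta2: float) -> torch.Tensor:
    # EMA of the *mean* squared gradient: one scalar statistic for the
    # whole tensor, matching the update in the hunk above.
    scalar_exp_avg_sq.mul_(beta2).add_((grad ** 2).mean(), alpha=1 - beta2)
    return scalar_exp_avg_sq

# Illustrative usage: beta2 = 0.98 and a random gradient (values made up).
state = {"scalar_exp_avg_sq": torch.zeros(())}
g = torch.randn(16, 16)
update_scalar_second_moment(state["scalar_exp_avg_sq"], g, beta2=0.98)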
@@ -525,11 +524,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                       grad: Tensor,
                       state: dict):
         """
-        A form of the core update for tensors with a small number of elements,
-        e.g. scalars. This is Adam where, if the numel() > 1, the learning rate
-        is proportional to the parameter rms value.
+        A form of the core update for scalar tensors, where we cannot get a good
+        estimate of the parameter rms.
         """
-        exp_avg_sq = state["scalar_exp_avg_sq"]
+        exp_avg_sq = state["exp_avg_sq"]
         exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
                                         value=1-beta2)
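
The fix is the state key: the scalar path now reads "exp_avg_sq" rather than "scalar_exp_avg_sq", the key used by the projected update in the first hunk, so the scalar update consults the statistic actually maintained for scalar tensors. A minimal Adam-style sketch of such a scalar step follows; the lr and eps values, and the omission of momentum and bias correction, are simplifications for illustration and are not shown in the diff.

import torch

def scalar_adam_step(p: torch.Tensor, grad: torch.Tensor, state: dict,
                     lr: float = 0.025, beta2: float = 0.98,
                     eps: float = 1e-8) -> None:
    # Second moment lives under its own "exp_avg_sq" key, as in the
    # fixed line of the diff.
    exp_avg_sq = state["exp_avg_sq"]
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    # Plain Adam-style step; lr/eps and the absence of momentum and
    # bias correction are assumptions, not the optimizer's behaviour.
    denom = exp_avg_sq.sqrt().add_(eps)
    p.addcdiv_(grad, denom, value=-lr)

# Illustrative usage on a 0-dim (scalar) parameter.
p = torch.tensor(0.5)
state = {"exp_avg_sq": torch.zeros_like(p)}
scalar_adam_step(p, torch.tensor(0.1), state)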