Fix bug for scalar update

Daniel Povey 2022-07-09 10:14:20 +08:00
parent aa2237a793
commit 65bc964854


@@ -497,7 +497,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # and project back..
         grad = self._project(grad, state, forward=False)
         scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
         scalar_exp_avg_sq.mul_(beta2).add_((grad**2).mean(),
                                            alpha=1-beta2)
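
For reference, the unchanged context above keeps a single scalar running
average of the mean squared gradient for the whole tensor. Below is a
minimal, self-contained sketch of that update, assuming PyTorch; the
standalone helper name is hypothetical and this is not the icefall source:

    import torch

    def update_scalar_exp_avg_sq(scalar_exp_avg_sq: torch.Tensor,
                                 grad: torch.Tensor,
                                 beta2: float) -> torch.Tensor:
        # In-place EMA: v <- beta2 * v + (1 - beta2) * mean(grad**2)
        scalar_exp_avg_sq.mul_(beta2).add_((grad ** 2).mean(),
                                           alpha=1 - beta2)
        return scalar_exp_avg_sq

    v = torch.zeros(())        # plays the role of state["scalar_exp_avg_sq"]
    g = torch.randn(4, 4)
    update_scalar_exp_avg_sq(v, g, beta2=0.98)
    assert torch.isclose(v, (1 - 0.98) * (g ** 2).mean())
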
@@ -525,11 +524,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                       grad: Tensor,
                       state: dict):
         """
-        A form of the core update for tensors with a small number of elements,
-        e.g. scalars. This is Adam where, if the numel() > 1, the learning rate
-        is proportional to the parameter rms value.
+        A form of the core update for scalar tensors, where we cannot get a good
+        estimate of the parameter rms.
         """
-        exp_avg_sq = state["scalar_exp_avg_sq"]
+        exp_avg_sq = state["exp_avg_sq"]
         exp_avg_sq.mul_(beta2).addcmul_(grad, grad,
                                         value=1-beta2)
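
With the corrected key, the scalar branch now reads the same "exp_avg_sq"
buffer it updates, instead of the "scalar_exp_avg_sq" buffer used by the
projected-tensor path. The sketch below shows how such a buffer is
typically consumed; the diff only shows the EMA line, so the denominator
and parameter step here are assumptions in the style of plain Adam, not
the icefall implementation:

    import torch

    def scalar_core_update(param: torch.Tensor, grad: torch.Tensor,
                           state: dict, lr: float, beta2: float,
                           eps: float = 1e-8) -> None:
        exp_avg_sq = state["exp_avg_sq"]   # the corrected key
        # v <- beta2 * v + (1 - beta2) * grad * grad, as in the diff
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        # Assumed Adam-style step: p <- p - lr * grad / (sqrt(v) + eps)
        denom = exp_avg_sq.sqrt().add_(eps)
        param.addcdiv_(grad, denom, value=-lr)

    p = torch.zeros(())
    state = {"exp_avg_sq": torch.zeros(())}
    scalar_core_update(p, torch.tensor(0.5), state, lr=0.1, beta2=0.98)
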