Simplify formula, getting rid of scalar_exp_avg_sq

Author: Daniel Povey
Date:   2022-07-11 17:14:12 -07:00
parent 4f0e219523
commit 41df045773


@@ -257,9 +257,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         state["param_rms"] = param_rms
         state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
-        # scalar exp_avg_sq. not the same as scale_exp_avg_sq, this is used to
-        # determine the magnitude of the regular step
-        state["scalar_exp_avg_sq"] = torch.zeros_like(param_rms)
         state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
                                            **kwargs)
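
The hunk above drops "scalar_exp_avg_sq" from the per-parameter state. For orientation, a minimal sketch of the remaining state layout follows; the function name, tensor shapes, and the way param_rms is computed are illustrative assumptions, and only the key names are taken from the diff.

    import torch

    def init_state_sketch(p: torch.Tensor, size_update_period: int = 4) -> dict:
        # Hypothetical helper, for illustration only.  param_rms is assumed to be
        # a small tensor holding the parameter's root-mean-square value.
        param_rms = (p ** 2).mean().sqrt().reshape(1)
        return {
            "param_rms": param_rms,
            "scale_exp_avg_sq": torch.zeros_like(param_rms),
            # No "scalar_exp_avg_sq" entry any more: this commit removes it, since
            # the magnitude of the regular step no longer needs its own stat.
            "scale_grads": torch.zeros(size_update_period, *param_rms.shape),
        }

    state = init_state_sketch(torch.randn(256, 256))
    print(sorted(state.keys()))
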
@@ -280,7 +277,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         state["trivial_update"] = trivial_update # so we know whether to zero exp_avg_sq stats.
         if trivial_update:
-            logging.info(f"Shape={p.shape}, trivial update.") # TODO: remove
             return
         # "zero_step" being a member of state is the sign that this parameter has
@@ -511,7 +507,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         Zero the exp_avg_sq stats, and set state["zero_step"] to state["step"]
         """
         state["exp_avg_sq"].zero_()
-        state["scalar_exp_avg_sq"].zero_()
         state["zero_step"] = state["step"]

     def _update_lrs(self,
@@ -599,6 +594,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             # (where the stats permit).
             scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()
+            if random.random() < 0.01:
                 logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::10]}, cur_param_var={cur_param_var[0].flatten()[::10]}, S={S[0].flatten()[::10]}")
             # scale shape: (batch_size, 1, size, 1, 1)
@@ -654,6 +650,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # Apply the scales in `cur_scales` to Q for each dim; this reflects the
         # parameter rms values in the parameter-diagonalized space, that we have
         # estimated in the loop above.
+        #
+        # We normalize the scales in such a way that the Frobenius norm
+        # after projecting (grad / denom) with Q should be unchanged, i.e. the
+        # same as (grad / denom), which is equivalent to having rms=1.0 due
+        # to how denom is constructed. This simplifies the normalization of the overall
+        # scale of the parameter change: we just have to multiply by the learning
+        # rate and param_rms.
         for dim in range(1, ndim):
             if cur_scales[dim] is not None:
                 size = p.shape[dim]
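
The comment added above argues that if the scales applied to Q have mean-square 1, then projecting a unit-rms quantity such as (grad / denom) leaves its Frobenius norm unchanged on average. Below is a small numerical check of that claim under simplifying assumptions: a single dim, an orthogonal Q obtained from a QR decomposition, and i.i.d. unit-variance entries standing in for (grad / denom); none of these names come from the actual optimizer code.

    import torch

    torch.manual_seed(0)
    size, nvec = 64, 10000

    x = torch.randn(size, nvec)                      # stands in for (grad / denom): rms ~= 1
    Q, _ = torch.linalg.qr(torch.randn(size, size))  # orthogonal projection

    scale = torch.rand(size, 1) + 0.1                # arbitrary positive per-coordinate scales
    scale = scale / (scale ** 2).mean().sqrt()       # normalize so mean(scale**2) == 1

    rms = lambda t: t.norm() / (size * nvec) ** 0.5
    print(rms(x))                # ~1.0
    print(rms(Q @ x))            # ~1.0: orthogonal Q preserves the Frobenius norm
    print(rms((scale * Q) @ x))  # ~1.0: the projected coordinates are uncorrelated with
                                 # equal variance, so rms-normalized scales leave the
                                 # expected Frobenius norm unchanged
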
@@ -661,6 +664,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 (batch_size, num_blocks, block_size, block_size) = Q.shape
                 scale = cur_scales[dim].reshape(batch_size, num_blocks, block_size, 1)
+                # The following normalization step will ensure the Frobenius
+                # norm is unchanged, from applying this scale: at least,
+                # assuming "grad / denom" gives uncorrelated outputs so that
+                # they will have equal variances after projecting to the space
+                # where the parameter var is diagonalized... this is *roughly*
+                # true because the gradients at the point where we compute "grad
+                # / denom" should be decorrelated at least considering
+                # individual tensor dims
+                scale /= _mean(scale**2, exclude_dims=[0], keepdim=True).sqrt()
                 # Q is indexed [batch_index, block_index, diagonalized_coordinate, canonical_coordinate],
                 # want to multiply on the diagonalized co-ordinate.
                 # else: Q is indexed [batch_index, canonical_coordinate].
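
`_mean(..., exclude_dims=[0], keepdim=True)` is the optimizer's own helper; the sketch below uses a hedged stand-in for it, just to show what the new normalization line does: after it, the mean of scale**2 over everything except the batch dimension is exactly 1.

    import torch

    def _mean_sketch(x: torch.Tensor, exclude_dims, keepdim: bool = True) -> torch.Tensor:
        # Stand-in for the real `_mean` helper: average over all dims not in `exclude_dims`.
        dims = [d for d in range(x.ndim) if d not in exclude_dims]
        return x.mean(dim=dims, keepdim=keepdim)

    batch_size, num_blocks, block_size = 3, 2, 8
    scale = torch.rand(batch_size, num_blocks, block_size, 1) + 0.1

    scale = scale / _mean_sketch(scale ** 2, exclude_dims=[0]).sqrt()

    # Per batch element, the mean of scale**2 is now 1.0, so applying `scale` to Q
    # does not change the overall gradient magnitude on average.
    print(_mean_sketch(scale ** 2, exclude_dims=[0]).flatten())   # tensor([1., 1., 1.])
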
@@ -825,21 +838,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # and project back..
         grad = self._project(grad, state, forward=False)
-        scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
-        scalar_exp_avg_sq.mul_(beta2).add_(_mean(grad**2,
-                                                 exclude_dims=[0],
-                                                 keepdim=True),
-                                           alpha=1-beta2)
-        if bias_correction2 < 0.99:
-            # scalar_exp_avg_sq is also zeroed at step "zero_step", so
-            # use the same bias_correction2.
-            scalar_exp_avg_sq = scalar_exp_avg_sq * (1.0 / bias_correction2)
-        denom = scalar_exp_avg_sq.sqrt() + eps # scalar, or [batch_size,1,1..]
-        alpha = state["param_rms"] * (1-beta1) * -lr / denom
+        alpha = -lr * (1-beta1) * state["param_rms"]
         delta = state["delta"]
         delta.add_(grad * alpha)
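
The net effect of this last hunk: because the scales applied to Q are rms-normalized above, the projected gradient is already assumed to have rms close to 1, so the per-tensor second-moment statistic (scalar_exp_avg_sq) and the division by its square root become redundant, and the step size collapses to -lr * (1-beta1) * param_rms. A before/after sketch with stand-in tensors; the hyperparameter values, shapes, and warm-up loop are illustrative, not taken from the diff.

    import torch

    lr, beta1, beta2, eps = 0.05, 0.9, 0.98, 1e-8
    param_rms = torch.tensor(0.1)
    step = 100                                  # pretend we are 100 steps past "zero_step"

    # Before this commit (sketch of the removed lines): keep a per-tensor second
    # moment of the projected gradient and divide the step size by its sqrt.
    scalar_exp_avg_sq = torch.zeros(1)
    for _ in range(step):
        grad = torch.randn(256, 256)            # stands in for the projected (grad / denom),
                                                # which is assumed to have rms ~= 1.
        scalar_exp_avg_sq.mul_(beta2).add_((grad ** 2).mean(), alpha=1 - beta2)
    bias_correction2 = 1 - beta2 ** step
    if bias_correction2 < 0.99:
        scalar_exp_avg_sq = scalar_exp_avg_sq * (1.0 / bias_correction2)
    denom = scalar_exp_avg_sq.sqrt() + eps
    alpha_old = param_rms * (1 - beta1) * -lr / denom

    # After this commit: denom ~= 1 by construction, so the division is redundant
    # and the step size is just -lr * (1 - beta1) * param_rms.
    alpha_new = -lr * (1 - beta1) * param_rms

    print(float(alpha_old), float(alpha_new))   # both close to -0.0005
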