mirror of https://github.com/k2-fsa/icefall.git
Simplify formula, getting rid of scalar_exp_avg_sq
parent 4f0e219523
commit 41df045773
@@ -257,9 +257,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 state["param_rms"] = param_rms

 state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
-# scalar exp_avg_sq.  not the same as scale_exp_avg_sq, this is used to
-# determine the magnitude of the regular step
-state["scalar_exp_avg_sq"] = torch.zeros_like(param_rms)
 state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
                                    **kwargs)

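As a point of reference, this is roughly the per-parameter state that remains once scalar_exp_avg_sq is dropped. The sketch below only mirrors the lines in the hunk above; the function name, the way param_rms is computed, and the example shapes are illustrative assumptions, not the optimizer's actual code.

    import torch

    def init_state_sketch(p: torch.Tensor, size_update_period: int = 4) -> dict:
        # Hypothetical helper: p is treated as a batch of parameters, and
        # param_rms holds one rms value per batch element (shape (batch, 1, 1, ...)).
        kwargs = {"device": p.device, "dtype": p.dtype}
        param_rms = (p ** 2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()

        state = {"param_rms": param_rms}
        # Running second moment used for the per-element scale factor.
        state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
        # NOTE: no state["scalar_exp_avg_sq"] any more, per this commit.
        state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, **kwargs)
        return state

    state = init_state_sketch(torch.randn(2, 8, 8))
    print({k: tuple(v.shape) for k, v in state.items()})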
@@ -280,7 +277,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 state["trivial_update"] = trivial_update  # so we know whether to zero exp_avg_sq stats.
 if trivial_update:
     logging.info(f"Shape={p.shape}, trivial update.")  # TODO: remove
     return

 # "zero_step" being a member of state is the sign that this parameter has
@@ -511,7 +507,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
     Zero the exp_avg_sq stats, and set state["zero_step"] to state["step"]
     """
     state["exp_avg_sq"].zero_()
-    state["scalar_exp_avg_sq"].zero_()
     state["zero_step"] = state["step"]

 def _update_lrs(self,
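The removed line also touches the zero_step bookkeeping: whenever the stats are zeroed, state["zero_step"] records the step at which that happened, and (per the comment deleted further down) the remaining stats share the same bias_correction2. A guess at how that correction might look, counted from zero_step rather than from step 0; this is an assumption for illustration, not code from the file.

    import torch

    def bias_corrected_denom_sketch(exp_avg_sq: torch.Tensor,
                                    step: int,
                                    zero_step: int,
                                    beta2: float = 0.98,
                                    eps: float = 1e-8) -> torch.Tensor:
        # Adam-style bias correction, but counting only the steps accumulated
        # since the stats were last zeroed at `zero_step` (assumption).
        bias_correction2 = 1 - beta2 ** (step - zero_step)
        return (exp_avg_sq / bias_correction2).sqrt() + eps

    print(bias_corrected_denom_sketch(torch.full((3,), 0.25), step=120, zero_step=100))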
@@ -599,6 +594,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 # (where the stats permit).
 scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()

 if random.random() < 0.01:
     logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::10]}, cur_param_var={cur_param_var[0].flatten()[::10]}, S={S[0].flatten()[::10]}")

 # scale shape: (batch_size, 1, size, 1, 1)
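The scale line above is a ratio of a target variance S to the parameter's current variance along one dim, clamped by eps on both sides before the square root. A toy illustration with stand-in values (S, cur_param_var and eps are defined earlier in the real file; the numbers here are made up):

    import torch

    eps = 1e-8
    S = torch.tensor([4.0, 1.0, 0.25])             # target variance per coordinate (stand-in)
    cur_param_var = torch.tensor([1.0, 1.0, 1.0])  # current variance (stand-in)

    # Same expression as in the hunk: coordinates below their target variance get
    # scale > 1, coordinates above it get scale < 1.
    scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()
    print(scale)  # tensor([2.0000, 1.0000, 0.5000])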
@@ -654,6 +650,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 # Apply the scales in `cur_scales` to Q for each dim; this reflects the
 # parameter rms values in the parameter-diagonalized space, that we have
 # estimated in the loop above.
+#
+# We normalize the scales in such a way that the Frobenius norm
+# after projecting (grad / denom) with Q should be unchanged, i.e. the
+# same as (grad / denom), which is equivalent to having rms=1.0 due
+# to how denom is constructed.  This simplifies the normalization of the overall
+# scale of the parameter change: we just have to multiply by the learning
+# rate and param_rms.
 for dim in range(1, ndim):
     if cur_scales[dim] is not None:
         size = p.shape[dim]
@@ -661,6 +664,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 (batch_size, num_blocks, block_size, block_size) = Q.shape

 scale = cur_scales[dim].reshape(batch_size, num_blocks, block_size, 1)
+# The following normalization step will ensure the Frobenius
+# norm is unchanged, from applying this scale: at least,
+# assuming "grad / denom" gives uncorrelated outputs so that
+# they will have equal variances after projecting to the space
+# where the parameter var is diagonalized... this is *roughly*
+# true because the gradients at the point where we compute "grad
+# / denom" should be decorrelated at least considering
+# individual tensor dims
+scale /= _mean(scale**2, exclude_dims=[0], keepdim=True).sqrt()
+
 # Q is indexed [batch_index, block_index, diagonalized_coordinate, canonical_coordinate],
 # want to multiply on the diagonalized co-ordinate.
 # else: Q is indexed [batch_index, canonical_coordinate].
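A quick numerical check of the claim in the new comments: if the entries of the projected (grad / denom) are roughly uncorrelated with equal variance, dividing the scale vector by its root-mean-square leaves the expected Frobenius norm of the scaled result unchanged. The snippet below is self-contained and uses a plain torch.mean in place of the file's _mean(..., exclude_dims=[0], keepdim=True); everything in it is illustrative.

    import torch

    torch.manual_seed(0)

    # x stands in for (grad / denom) after projecting with Q: approximately
    # uncorrelated, unit-variance entries.
    x = torch.randn(512, 64)

    scale = torch.rand(64) + 0.5
    # Mirror of: scale /= _mean(scale**2, exclude_dims=[0], keepdim=True).sqrt()
    scale = scale / (scale ** 2).mean().sqrt()

    before = (x ** 2).sum(dim=1).mean()            # average squared Frobenius norm per row
    after = ((x * scale) ** 2).sum(dim=1).mean()   # same, after applying the normalized scale
    print(before.item(), after.item())             # approximately equal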
@@ -825,21 +838,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 # and project back..
 grad = self._project(grad, state, forward=False)

-scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
-scalar_exp_avg_sq.mul_(beta2).add_(_mean(grad**2,
-                                         exclude_dims=[0],
-                                         keepdim=True),
-                                   alpha=1-beta2)
-

-if bias_correction2 < 0.99:
-    # scalar_exp_avg_sq is also zeroed at step "zero_step", so
-    # use the same bias_correction2.
-    scalar_exp_avg_sq = scalar_exp_avg_sq * (1.0 / bias_correction2)
-
-denom = scalar_exp_avg_sq.sqrt() + eps  # scalar, or [batch_size,1,1..]
-
-alpha = state["param_rms"] * (1-beta1) * -lr / denom
+alpha = -lr * (1-beta1) * state["param_rms"]

 delta = state["delta"]
 delta.add_(grad * alpha)
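Net effect of this hunk: the step size no longer divides by a running scalar rms of the projected gradient. That is consistent with the normalization added above, which aims to keep the back-projected (grad / denom) at rms close to 1, so the old division was approximately a no-op. A hedged side-by-side sketch with stand-in tensors (names mirror the diff; the single-step bias correction is simplified):

    import torch

    lr, beta1, beta2, eps = 0.05, 0.9, 0.98, 1e-8
    param_rms = torch.tensor([[0.1]])
    grad = torch.randn(1, 16)                 # stands in for the back-projected (grad / denom)

    # Old formula (removed): normalize by a bias-corrected running rms of grad.
    scalar_exp_avg_sq = torch.zeros(1, 1)
    scalar_exp_avg_sq.mul_(beta2).add_((grad ** 2).mean(dim=1, keepdim=True), alpha=1 - beta2)
    denom = (scalar_exp_avg_sq / (1 - beta2)).sqrt() + eps
    alpha_old = param_rms * (1 - beta1) * -lr / denom

    # New formula (added): rely on grad already having rms ~= 1.
    alpha_new = -lr * (1 - beta1) * param_rms

    print(alpha_old, alpha_new)               # close when grad has rms ~= 1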
@@ -1310,7 +1310,7 @@ class Cain(Optimizer):
 bias_correction2 = 1 - beta2 ** step
 denom = (exp_avg_sq.sqrt()).add_(eps)
 this_delta = grad / denom
-alpha = -lr*(1-beta1)*(bias_correction2 ** 0.5)
+alpha = -lr * (1-beta1) * (bias_correction2 ** 0.5)
 delta.add_(this_delta, alpha=alpha)

 p.add_(delta)
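The final hunk is only a whitespace cleanup of the step-size line in the simpler Cain optimizer, but for completeness here is the Adam-style update it sits in, as a self-contained sketch with made-up hyperparameters (the real class maintains exp_avg_sq and delta as persistent state):

    import torch

    lr, beta1, beta2, eps, step = 0.003, 0.9, 0.98, 1e-8, 10
    p = torch.zeros(4)
    grad = torch.randn(4)
    exp_avg_sq = grad ** 2                 # stand-in for the running second-moment average
    delta = torch.zeros_like(p)            # stand-in for the momentum-like delta buffer

    bias_correction2 = 1 - beta2 ** step
    denom = (exp_avg_sq.sqrt()).add_(eps)
    this_delta = grad / denom
    alpha = -lr * (1-beta1) * (bias_correction2 ** 0.5)
    delta.add_(this_delta, alpha=alpha)

    p.add_(delta)
    print(p)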