mirror of https://github.com/k2-fsa/icefall.git
Simplify formula, getting rid of scalar_exp_avg_sq
commit 41df045773 (parent 4f0e219523)
@@ -257,9 +257,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         state["param_rms"] = param_rms

         state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
-        # scalar exp_avg_sq.  not the same as scale_exp_avg_sq, this is used to
-        # determine the magnitude of the regular step
-        state["scalar_exp_avg_sq"] = torch.zeros_like(param_rms)
         state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape,
                                            **kwargs)

@@ -280,7 +277,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of

         state["trivial_update"] = trivial_update  # so we know whether to zero exp_avg_sq stats.
         if trivial_update:
-            logging.info(f"Shape={p.shape}, trivial update.")  # TODO: remove
             return

         # "zero_step" being a member of state is the sign that this parameter has
@@ -511,7 +507,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        Zero the exp_avg_sq stats, and set state["zero_step"] to state["step"]
        """
        state["exp_avg_sq"].zero_()
-       state["scalar_exp_avg_sq"].zero_()
        state["zero_step"] = state["step"]

    def _update_lrs(self,
@@ -599,7 +594,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
            # (where the stats permit).
            scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()

-           logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::10]}, cur_param_var={cur_param_var[0].flatten()[::10]}, S={S[0].flatten()[::10]}")
+           if random.random() < 0.01:
+               logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::10]}, cur_param_var={cur_param_var[0].flatten()[::10]}, S={S[0].flatten()[::10]}")

            # scale shape: (batch_size, 1, size, 1, 1)
            cur_p *= scale
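Note on the change above: the diagnostic line is now logged only with probability 0.01, so the cost of formatting and emitting it is paid rarely inside the hot training loop. A minimal standalone sketch of the same throttling idea (none of these names come from the diff):

import logging
import random

logging.basicConfig(level=logging.INFO)

def maybe_log(msg: str, prob: float = 0.01) -> None:
    # Emit the message with probability `prob`; stay silent otherwise.
    if random.random() < prob:
        logging.info(msg)

for step in range(1000):
    maybe_log(f"step={step}")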
@@ -654,6 +650,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
        # Apply the scales in `cur_scales` to Q for each dim; this reflects the
        # parameter rms values in the parameter-diagonalized space, that we have
        # estimated in the loop above.
+       #
+       # We normalize the scales in such a way that the Frobenius norm
+       # after projecting (grad / denom) with Q should be unchanged, i.e. the
+       # same as (grad / denom), which is equivalent to having rms=1.0 due
+       # to how denom is constructed.  This simplifies the normalization of the overall
+       # scale of the parameter change: we just have to multiply by the learning
+       # rate and param_rms.
        for dim in range(1, ndim):
            if cur_scales[dim] is not None:
                size = p.shape[dim]
@@ -661,6 +664,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                (batch_size, num_blocks, block_size, block_size) = Q.shape

                scale = cur_scales[dim].reshape(batch_size, num_blocks, block_size, 1)
+               # The following normalization step will ensure the Frobenius
+               # norm is unchanged, from applying this scale: at least,
+               # assuming "grad / denom" gives uncorrelated outputs so that
+               # they will have equal variances after projecting to the space
+               # where the parameter var is diagonalized... this is *roughly*
+               # true because the gradients at the point where we compute "grad
+               # / denom" should be decorrelated at least considering
+               # individual tensor dims
+               scale /= _mean(scale**2, exclude_dims=[0], keepdim=True).sqrt()
+
                # Q is indexed [batch_index, block_index, diagonalized_coordinate, canonical_coordinate],
                # want to multiply on the diagonalized co-ordinate.
                # else: Q is indexed [batch_index, canonical_coordinate].
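The comments added in the two hunks above reason that dividing `scale` by its per-batch-element root-mean-square keeps the Frobenius norm of the projected (grad / denom) roughly unchanged, assuming its entries are uncorrelated with equal variances. Below is a small self-contained sketch of that normalization; the `_mean` helper here is an assumed stand-in modeled on the helper this file uses (average over all dims except `exclude_dims`), not the actual implementation:

import torch

def _mean(x: torch.Tensor, exclude_dims, keepdim: bool) -> torch.Tensor:
    # Assumed behaviour: average over every dim not listed in exclude_dims.
    dims = [d for d in range(x.ndim) if d not in exclude_dims]
    return x.mean(dim=dims, keepdim=keepdim)

batch_size, num_blocks, block_size = 2, 3, 8
scale = torch.rand(batch_size, num_blocks, block_size, 1) + 0.5

# Normalize so that rms(scale) == 1 separately for each batch element.
scale = scale / _mean(scale**2, exclude_dims=[0], keepdim=True).sqrt()
print(_mean(scale**2, exclude_dims=[0], keepdim=False))  # -> values close to 1.0

# With rms(scale) == 1 and uncorrelated, equal-variance entries in x,
# multiplying by scale leaves the expected squared Frobenius norm unchanged.
x = torch.randn(batch_size, num_blocks, block_size, 10000)
before = torch.linalg.vector_norm(x, dim=(1, 2, 3))
after = torch.linalg.vector_norm(scale * x, dim=(1, 2, 3))
print(before, after)  # roughly equal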
@@ -825,21 +838,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
            # and project back..
            grad = self._project(grad, state, forward=False)

-           scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
-           scalar_exp_avg_sq.mul_(beta2).add_(_mean(grad**2,
-                                                    exclude_dims=[0],
-                                                    keepdim=True),
-                                              alpha=1-beta2)
-
-           if bias_correction2 < 0.99:
-               # scalar_exp_avg_sq is also zeroed at step "zero_step", so
-               # use the same bias_correction2.
-               scalar_exp_avg_sq = scalar_exp_avg_sq * (1.0 / bias_correction2)
-
-           denom = scalar_exp_avg_sq.sqrt() + eps  # scalar, or [batch_size,1,1..]
-
-           alpha = state["param_rms"] * (1-beta1) * -lr / denom
+           alpha = -lr * (1-beta1) * state["param_rms"]

            delta = state["delta"]
            delta.add_(grad * alpha)
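The hunk above is the simplification named in the commit title: the running average scalar_exp_avg_sq, its bias correction, and the resulting denom are dropped, and the step size becomes just -lr * (1-beta1) * param_rms, on the argument that the projected gradient already has rms close to 1. A toy before/after sketch under that assumption (arbitrary shapes and values, not the optimizer itself):

import torch

lr, beta1, beta2, eps = 0.05, 0.9, 0.98, 1.0e-8
grad = torch.randn(4, 16, 16)              # stand-in for the projected grad, rms ~= 1
param_rms = torch.full((4, 1, 1), 0.1)     # stand-in for state["param_rms"]
delta = torch.zeros_like(grad)

# Old path: keep a running mean of grad**2 and divide it back out.
scalar_exp_avg_sq = (grad**2).mean(dim=(1, 2), keepdim=True)   # stand-in for the EMA
denom = scalar_exp_avg_sq.sqrt() + eps
alpha_old = param_rms * (1 - beta1) * -lr / denom

# New path: since the projected grad is already normalized to rms ~= 1 by the
# scale handling above, denom ~= 1 and can be dropped.
alpha_new = -lr * (1 - beta1) * param_rms

delta.add_(grad * alpha_new)
print(alpha_old.flatten()[:2], alpha_new.flatten()[:2])   # similar when rms(grad) ~= 1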
@@ -1310,7 +1310,7 @@ class Cain(Optimizer):
                bias_correction2 = 1 - beta2 ** step
                denom = (exp_avg_sq.sqrt()).add_(eps)
                this_delta = grad / denom
-               alpha = -lr*(1-beta1)*(bias_correction2 ** 0.5)
+               alpha = -lr * (1-beta1) * (bias_correction2 ** 0.5)
                delta.add_(this_delta, alpha=alpha)

                p.add_(delta)
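For reference, the final hunk only reformats the alpha expression inside Cain's Adam-style step. A toy sketch of how that line sits in the surrounding update (the beta1 decay of delta is an assumption about code outside this hunk, included only so the step is complete):

import torch

lr, beta1, beta2, eps, step = 0.003, 0.9, 0.98, 1.0e-8, 100
p = torch.randn(10, 10)
grad = torch.randn_like(p)
exp_avg_sq = torch.full_like(p, 0.5)   # running average of grad**2
delta = torch.zeros_like(p)            # momentum of the parameter change

exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
bias_correction2 = 1 - beta2 ** step

denom = (exp_avg_sq.sqrt()).add_(eps)
this_delta = grad / denom
alpha = -lr * (1 - beta1) * (bias_correction2 ** 0.5)

delta.mul_(beta1)                      # assumed: decay applied elsewhere in the real code
delta.add_(this_delta, alpha=alpha)
p.add_(delta)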