Add _update_param_scales_simple(), add documentation

Daniel Povey 2022-07-23 14:49:58 +08:00
parent 9730352257
commit 2c4bdd0ad0


@@ -555,15 +555,42 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         return (step >= zero_step + cur_update_period)
 
+    def _update_param_scales_simple(self,
+                                    group: dict,
+                                    p: Tensor,
+                                    state: dict,
+                                    P_proj: List[Optional[Tensor]]) -> None:
+        for dim in range(1, p.ndim):
+            size = p.shape[dim]
+            try:
+                Q = state[f"Q_{dim}"]
+            except KeyError:
+                assert size == 1 or size == numel, size
+                continue  # e.g. size == 1 or size == numel
+            (batch_size, num_blocks, block_size, block_size) = Q.shape
+            this_P_proj = P_proj[dim].reshape(batch_size, num_blocks, block_size, 1)
+            # The following normalization step will ensure the Frobenius
+            # norm is unchanged, from applying this scale: at least,
+            # assuming "grad / denom" gives uncorrelated outputs so that
+            # they will have equal variances after projecting to the space
+            # where the parameter var is diagonalized... this is *roughly*
+            # true because the gradients at the point where we compute "grad
+            # / denom" should be decorrelated, at least considering
+            # individual tensor dims.
+            this_P_proj /= _mean(this_P_proj, exclude_dims=[0], keepdim=True)
+            Q *= this_P_proj.sqrt()
     def _update_param_scales(self,
                              group: dict,
                              p: Tensor,
                              state: dict,
                              P_proj: List[Optional[Tensor]]) -> None:
         """
-        Computes learning-rate matrices Q for each dim of this tensor: only the part that depends
-        on the parameter covariance, we will later add a rotation that depends on the gradient
-        covariance.
+        Modifies the scales on the rows of the learning-rate matrices Q for each dim of this tensor,
+        to take into account the estimated parameter covariance.
 
         Args:
            group: dict to look up configuration values
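
[Editorial note, not part of the commit.] The normalization step inside the new _update_param_scales_simple() divides this_P_proj by its mean over everything except the batch dim, which makes the mean of the squared row-scales exactly 1; that is what keeps the expected Frobenius norm of "grad / denom" unchanged under the comment's assumption of uncorrelated, equal-variance outputs. A small standalone sketch of this, with plain torch.mean standing in for the optimizer's _mean() helper and made-up shapes:

import torch

torch.manual_seed(0)
batch_size, num_blocks, block_size = 2, 3, 4

# Stand-in for P_proj[dim] after the reshape in _update_param_scales_simple():
# one positive variance estimate per row of each diagonal block of Q.
this_P_proj = torch.rand(batch_size, num_blocks, block_size, 1) + 0.1

# Plays the role of: this_P_proj /= _mean(this_P_proj, exclude_dims=[0], keepdim=True)
this_P_proj = this_P_proj / this_P_proj.mean(dim=(1, 2, 3), keepdim=True)
scale = this_P_proj.sqrt()  # the factor multiplied onto the rows of Q

# With uncorrelated, equal-variance outputs of "grad / denom", the expected
# squared Frobenius norm of the update is proportional to the mean of
# scale ** 2, which the normalization pins to 1 for every batch element.
print((scale ** 2).mean(dim=(1, 2, 3)))  # ~ tensor([1., 1.])
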
@@ -636,13 +663,19 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             S = smoothed_param_var.reshape(cur_param_var.shape)  # (batch_size, 1, size, 1, 1) if dim==2
             # OK, cur_param_var would have the values as S if the variance stats
-            # param_cov_{dim} were accumulated from this exact parameter matrix,
-            # but actually they contain older versions of the parameter
-            # covariance so they will, in general, be less extreme ("whiter
-            # spectrum").  We scale p so that it matches the accumulated stats,
-            # the idea is to ensure it doesn't have any too-small eigenvalues
-            # (where the stats permit).
+            # P_proj[dim] were accumulated from this exact parameter matrix and
+            # not smoothed, but actually they contain older versions of the
+            # parameter covariance and they have been smoothed, so they will, in
+            # general, be less extreme ("whiter spectrum").  We scale p so that
+            # it matches the estimated variance P_proj[dim]; the idea is to ensure it doesn't
+            # have too-extreme eigenvalues (where the stats permit).
+            # Actually we could just use P_proj[dim].sqrt(), suitably scaled,
+            # as the scales on the rows of Q (see _update_param_scales_simple() which does
+            # exactly this), but there is a problem of "counting things twice"
+            # which is easiest to understand for a 2-dimensional tensor, where the
+            # singular values show up identically in the covariance over either axis.
+            # The estimation procedure in this function avoids the "counting things twice"
+            # problem, at the expense of quite a bit of extra complexity.
             scale = (S / cur_param_var.clamp(min=eps)).sqrt()
             if True:
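
[Editorial note, not part of the commit.] The "counting things twice" problem described in the new comment can be checked directly for a 2-dimensional tensor: the covariance of a matrix over either of its dims has the squared singular values as its nonzero eigenvalues, so scaling each dim independently by the square root of its projected variance would apply those singular values twice. A small standalone check in plain PyTorch:

import torch

torch.manual_seed(0)
p = torch.randn(4, 6)

cov_dim0 = p @ p.t()   # (4, 4): covariance of p over dim 0
cov_dim1 = p.t() @ p   # (6, 6): covariance of p over dim 1

s2 = torch.linalg.svdvals(p) ** 2                   # squared singular values, descending
eig0 = torch.linalg.eigvalsh(cov_dim0).flip(0)      # eigenvalues, descending
eig1 = torch.linalg.eigvalsh(cov_dim1).flip(0)[:4]  # top 4; the remaining two are ~0

print(torch.allclose(eig0, s2, atol=1e-5))  # True
print(torch.allclose(eig1, s2, atol=1e-5))  # True
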
@@ -650,7 +683,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 cur_tmp = cur_param_var.reshape(batch_size, size)
                 scale_tmp = scale.reshape(batch_size, size)
                 skip = 10 if size > 40 else 1
-                logging.info(f"cur_param_var={cur_tmp[0][::skip]}, S={S_tmp[0][::skip]}, scale={scale_tmp[0][::skip]}")
+                logging.info(f"dim={dim}/{ndim}, cur_param_var={cur_tmp[0][::skip]}, S={S_tmp[0][::skip]}, scale={scale_tmp[0][::skip]}")
             if random.random() < 0.01:
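
[Editorial note, not part of the commit.] For reference, the line scale = (S / cur_param_var.clamp(min=eps)).sqrt() above is exactly the per-row factor that moves the current variance onto the smoothed target S. A tiny standalone illustration of that identity, with synthetic shapes and a made-up smoothing rule rather than the optimizer's actual tensors:

import torch

torch.manual_seed(0)
eps = 1.0e-20

# Fake parameter matrix whose rows have deliberately uneven variances.
p = torch.randn(8, 16) * torch.linspace(0.1, 2.0, 8).unsqueeze(1)

cur_param_var = (p ** 2).mean(dim=1, keepdim=True)   # current per-row variance
S = 0.5 * (cur_param_var + cur_param_var.mean())     # made-up "smoothed" target variance

scale = (S / cur_param_var.clamp(min=eps)).sqrt()
p_rescaled = p * scale

# Each row's variance now matches the target S exactly.
print(torch.allclose((p_rescaled ** 2).mean(dim=1, keepdim=True), S))  # True
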