Add _update_param_scales_simple(), add documentation
This commit is contained in: parent 9730352257, commit 2c4bdd0ad0
@@ -555,15 +555,42 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         return (step >= zero_step + cur_update_period)

+    def _update_param_scales_simple(self,
+                                    group: dict,
+                                    p: Tensor,
+                                    state: dict,
+                                    P_proj: List[Optional[Tensor]]) -> None:
+        for dim in range(1, p.ndim):
+            size = p.shape[dim]
+            try:
+                Q = state[f"Q_{dim}"]
+            except KeyError:
+                assert size == 1 or size == p.numel(), size
+                continue  # e.g. size == 1 or size == p.numel()
+
+            (batch_size, num_blocks, block_size, block_size) = Q.shape
+            this_P_proj = P_proj[dim].reshape(batch_size, num_blocks, block_size, 1)
+
+            # The following normalization step will ensure the Frobenius
+            # norm is unchanged by applying this scale: at least,
+            # assuming "grad / denom" gives uncorrelated outputs, so that
+            # they will have equal variances after projecting to the space
+            # where the parameter var is diagonalized... this is *roughly*
+            # true because the gradients at the point where we compute
+            # "grad / denom" should be decorrelated, at least considering
+            # individual tensor dims.
+            this_P_proj /= _mean(this_P_proj, exclude_dims=[0], keepdim=True)
+
+            Q *= this_P_proj.sqrt()
+
     def _update_param_scales(self,
                              group: dict,
                              p: Tensor,
                              state: dict,
                              P_proj: List[Optional[Tensor]]) -> None:
+        """
+        Computes learning-rate matrices Q for each dim of this tensor: only the part that depends
+        on the parameter covariance; we will later add a rotation that depends on the gradient
+        covariance.
+        Modifies the scales on the rows of the learning-rate matrices Q for each dim of this tensor,
+        to take into account the estimated parameter covariance.
+
+        Args:
+           group: dict to look up configuration values
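The normalization in _update_param_scales_simple() relies on a simple invariant: per-row scales whose mean square is 1 leave the expected Frobenius norm unchanged, provided the inputs are uncorrelated with equal variance (the "roughly true" assumption in the comment). A minimal standalone sketch of that claim, not part of the commit, with illustrative names:

    import torch

    torch.manual_seed(0)
    n = 64
    P_proj = torch.rand(n) + 0.1       # stand-in for a projected variance estimate
    P_proj /= P_proj.mean()            # normalize so the mean is 1, as in the commit
    s = P_proj.sqrt()                  # per-row scales applied to Q

    g = torch.randn(100_000, n)        # uncorrelated, equal-variance "grad / denom" samples
    print((g ** 2).sum(dim=1).mean())         # E||g||^2, about n
    print(((g * s) ** 2).sum(dim=1).mean())   # E||s*g||^2, nearly identical

Because mean(s**2) == 1, the expected squared norm sum_i(s_i**2 * var_i) is unchanged; correlated or unequal-variance inputs would break this, which is why the comment only claims it holds roughly.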
@@ -636,13 +663,19 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         S = smoothed_param_var.reshape(cur_param_var.shape)  # (batch_size, 1, size, 1, 1) if dim==2

-        # OK, cur_param_var would have the same values as S if the variance stats
-        # param_cov_{dim} were accumulated from this exact parameter matrix,
-        # but actually they contain older versions of the parameter
-        # covariance so they will, in general, be less extreme ("whiter
-        # spectrum").  We scale p so that it matches the accumulated stats;
-        # the idea is to ensure it doesn't have any too-small eigenvalues
-        # (where the stats permit).
+        # cur_param_var would have the same values as S if the stats
+        # P_proj[dim] were accumulated from this exact parameter matrix and
+        # not smoothed; but actually they contain older versions of the
+        # parameter covariance and they have been smoothed, so they will, in
+        # general, be less extreme ("whiter spectrum").  We scale p so that
+        # it matches the estimated variance P_proj[dim]; the idea is to ensure
+        # it doesn't have too-extreme eigenvalues (where the stats permit).
+        # Actually we could just use P_proj[dim].sqrt(), suitably scaled,
+        # as the scales on the rows of Q (see _update_param_scales_simple(),
+        # which does exactly this), but there is a problem of "counting things
+        # twice", which is easiest to understand for a 2-dimensional tensor,
+        # where the singular values show up identically in the covariance over
+        # either axis.  The estimation procedure in this function avoids the
+        # "counting things twice" problem, at the expense of quite a bit of
+        # extra complexity.
         scale = (S / cur_param_var.clamp(min=eps)).sqrt()

         if True:
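The "counting things twice" problem the comment describes can be seen concretely for a 2-D matrix: the covariance over either axis has the squared singular values as its eigenvalues, so scaling the rows of Q for every dim by P_proj[dim].sqrt() would apply the same spectrum once per axis. A small demonstration, illustrative only and not from the commit:

    import torch

    torch.manual_seed(0)
    W = torch.randn(5, 5)
    sv2 = torch.linalg.svdvals(W) ** 2           # squared singular values
    row_eigs = torch.linalg.eigvalsh(W @ W.t())  # covariance over dim 0
    col_eigs = torch.linalg.eigvalsh(W.t() @ W)  # covariance over dim 1

    # All three sets agree up to ordering and numerical error:
    print(sv2.sort().values)
    print(row_eigs)   # eigvalsh already returns ascending order
    print(col_eigs)

This is why _update_param_scales() goes to the extra trouble described above, while _update_param_scales_simple() accepts the double counting in exchange for simplicity.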
@@ -650,7 +683,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             cur_tmp = cur_param_var.reshape(batch_size, size)
             scale_tmp = scale.reshape(batch_size, size)
             skip = 10 if size > 40 else 1
-            logging.info(f"cur_param_var={cur_tmp[0][::skip]}, S={S_tmp[0][::skip]}, scale={scale_tmp[0][::skip]}")
+            logging.info(f"dim={dim}/{ndim}, cur_param_var={cur_tmp[0][::skip]}, S={S_tmp[0][::skip]}, scale={scale_tmp[0][::skip]}")


         if random.random() < 0.01:
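The scale = (S / cur_param_var.clamp(min=eps)).sqrt() step in the second hunk is standard variance matching: multiplying a zero-mean quantity of variance v by sqrt(S / v) yields variance S. A quick check with simplified 1-D shapes and illustrative values, not from the commit:

    import torch

    torch.manual_seed(0)
    eps = 1e-8
    cur_param_var = torch.tensor([4.0, 1.0, 0.25])  # current per-direction variances
    S = torch.tensor([2.0, 1.0, 0.5])               # smoothed target variances
    scale = (S / cur_param_var.clamp(min=eps)).sqrt()

    x = torch.randn(100_000, 3) * cur_param_var.sqrt()  # samples with variance cur_param_var
    print((x * scale).var(dim=0))                       # approximately S

The clamp(min=eps) guards against division by zero when a direction's current variance estimate collapses.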