First version after refactorization and changing the math, where optim.py runs
This commit is contained in:
parent
4da4e69fba
commit
dd10eb140f
@@ -142,10 +142,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
lr=3e-02,
betas=(0.9, 0.98),
size_lr_scale=0.1,
param_pow=1.0,
min_lr_factor=(0.05, 0.05, 0.05),
max_lr_factor=(10.0, 10.0, 10.0),
param_rms_smooth0=0.75,
param_rms_smooth1=0.25,
max_lr_factor=10.0,
eps=1.0e-08,
param_min_rms=1.0e-05,
param_max_rms=2.0,
@@ -153,7 +153,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
size_update_period=4,
lr_update_period=(200, 1000),
grad_cov_period=3,
param_cov_period=100,
max_block_size=1024,
):

@@ -162,10 +161,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
defaults = dict(
lr=lr,
size_lr_scale=size_lr_scale,
param_pow=param_pow,
min_lr_factor=min_lr_factor,
max_lr_factor=max_lr_factor,
param_rms_smooth0=param_rms_smooth0,
param_rms_smooth1=param_rms_smooth1,
max_lr_factor=max_lr_factor,
betas=betas,
eps=eps,
param_min_rms=param_min_rms,
@@ -174,7 +173,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
size_update_period=size_update_period,
lr_update_period=lr_update_period,
grad_cov_period=grad_cov_period,
param_cov_period=param_cov_period,
max_block_size=max_block_size,
)

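Note that min_lr_factor and max_lr_factor are now 3-tuples: as _smooth_param_cov further down in this diff shows, element [0] is used for the first smoothing pass on P_norm, element [1] for the pass relative to G, and element [2] for the final pass. A minimal sketch (not code from this diff) of what one such min/max pair does to an eigenvalue spectrum, following the transformation written out in the _smooth_cov docstring below:

# Hedged sketch: effect of one (min_lr_factor, max_lr_factor) pair on a toy spectrum.
import torch

L = torch.tensor([1e-4, 0.1, 1.0, 50.0])    # toy eigenvalue spectrum
min_eig, max_eig, eps = 0.05, 10.0, 1e-10    # defaults for one smoothing stage
L = L + min_eig * L.mean() + eps             # floor small eigenvalues near min_eig * mean
L = L / L.mean()
L = 1.0 / (1.0 / L + 1.0 / max_eig)          # soft-min: large eigenvalues limited by max_eig
L = L / L.mean()
print(L)  # small entries pulled up, large ones softly limited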
@@ -283,7 +281,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
return

# "zero_step" being a member of state is the sign that this parameter has
# at least one dim that has a projection.
# at least one dim that has a projection. It also records the most recent
# step on which we zeroed state["exp_avg_sq"]
state["zero_step"] = 0
# last_param_scale_update records the last time we updated the part of the learning rate
# matrices that relates to the parameter covariance; we avoid doing this too often
@@ -359,7 +358,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
lr = group["lr"]
size_update_period = group["size_update_period"]
grad_cov_period = group["grad_cov_period"]
param_cov_period = group["param_cov_period"]
eps = group["eps"]
beta1 = group["betas"][0]

@@ -393,12 +391,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
else:
if self._is_lr_update_step(group, state):
self._update_param_cov(group, p, state)
if step > state["last_param_scale_update"] * 1.1 and state["last_param_scale_update"] != state["zero_step"]:
# Only update the parameter-dependent part of the learning
# rate matrices at most every other time we reach here, and
# less frequently than that later in training.
self._update_param_scales(group, p, state)
self._diagonalize_grad_cov(group, p, state)

P_proj = self._compute_bases(group, p.shape, state)
# Only update the parameter-dependent part of the learning
# rate matrices at most every other time we reach here, and
# less frequently than that later in training.
self._update_param_scales(group, p, state, P_proj)

# We won't be doing this any more.
#self._diagonalize_grad_cov(group, p, state)
self._zero_exp_avg_sq(state)
if step % grad_cov_period == 0:
self._update_grad_cov(group, p, state)
@@ -474,7 +475,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
(except batch and trivial and rank-1 dims)
"""
eps = group["eps"]
param_cov_period = group["param_cov_period"]

# zero_step is always the last time we called _update_param_cov.
# Our aim is to compute the parameter covariance averaged over all time
@@ -555,7 +555,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
def _update_param_scales(self,
group: dict,
p: Tensor,
state: dict) -> None:
state: dict,
P_proj: List[Optional[Tensor]]) -> None:
"""
Computes learning-rate matrices Q for each dim of this tensor: only the part that depends
on the parameter covariance; we will later add a rotation that depends on the gradient
@@ -567,6 +568,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
are actually factors of p, so p itself will change when we change
them.
state: state dict for the current parameter
P_proj: a list indexed by dim, containing tensors of shape (batch_size, size)
where size == p.shape[dim], which represent the parameter covariance
projected by state[f"Q_{dim}"] (i.e. by the value of this variable
at entry to this function)
"""
ndim = p.ndim
batch_size = p.shape[0]
@@ -597,15 +602,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
assert size == 1 or size == numel, size
continue # e.g. size == 1 or size == numel

param_cov = self._get_smoothed_param_cov(group, p, state, dim)

# param_cov has the same shape as Q
(batch_size, num_blocks, block_size, block_size) = Q.shape
# U,S,V have extra leading dimensions, i.e. of shape
# {U,V}.shape == (batch_size, num_blocks, block_size, block_size),
# S.shape == (batch_size, num_blocks, block_size).
U, S, V = _svd(param_cov)
Q[:] = U.transpose(2, 3)

M = cur_p.transpose(dim, -1)
# if p were of shape (batch_size, x, size, y, z),
@@ -615,11 +612,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
num_blocks, block_size)
M = _move_dim(M, -2, 1) # (batch_size, num_blocks, x, y, z, block_size)

while U.ndim < M.ndim:
U = U.unsqueeze(2)
# Now U is of shape (batch_size, num_blocks, 1, 1, block_size, block_size)

M = torch.matmul(M, U) # (batch_size, num_blocks, x, y, z, block_size)
while Q.ndim < M.ndim:
Q = Q.unsqueeze(2)
# Now Q is of shape (batch_size, num_blocks, 1, 1, block_size, block_size)
# [batch_index, block_index, diagonalized_coordinate, canonical_coordinate],
# so we need to transpose Q as we convert M to the diagonalized co-ordinate.
M = torch.matmul(M, Q.transpose(2, 3)) # (batch_size, num_blocks, x, y, z, block_size)
M = _move_dim(M, 1, -2) # (batch_size, x, y, z, num_blocks, block_size)
M = M.reshape(*M.shape[:-2], size) # (batch_size, x, y, z, size)
cur_p = M.transpose(dim, -1) # (batch_size, x, size, y, z)
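A minimal sketch of the indexing convention used above, where Q is indexed [batch_index, diagonalized_coordinate, canonical_coordinate], so multiplying by Q^T on the right takes the last dim of a tensor into the diagonalized basis and multiplying by Q brings it back. This is an illustration only, with a single block and plain torch calls rather than the _move_dim helper from optim.py:

# Hedged sketch, not code from the diff: round-trip through the diagonalized basis.
import torch

batch_size, x, size = 2, 3, 4
p = torch.randn(batch_size, x, size)
# a random orthogonal matrix per batch element, treated as [batch, diagonalized, canonical]
Q, _ = torch.linalg.qr(torch.randn(batch_size, size, size))
p_diag = torch.matmul(p, Q.transpose(1, 2))  # into diagonalized coordinates
p_back = torch.matmul(p_diag, Q)             # back to canonical coordinates
assert torch.allclose(p, p_back, atol=1e-5)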
@@ -630,7 +628,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
cur_param_var = _mean(cur_p**2,
exclude_dims=[0,dim],
keepdim=True) # (batch_size, 1, size, 1, 1) if dim==2
S = S.reshape(cur_param_var.shape) # (batch_size, 1, size, 1, 1) if dim==2
smoothed_param_var = P_proj[dim] # (batch_size, size)
S = smoothed_param_var.reshape(cur_param_var.shape) # (batch_size, 1, size, 1, 1) if dim==2

# OK, cur_param_var would have the same values as S if the variance stats
# param_cov_{dim} were accumulated from this exact parameter matrix,
@@ -639,7 +638,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# spectrum"). We scale p so that it matches the accumulated stats,
# the idea is to ensure it doesn't have any too-small eigenvalues
# (where the stats permit).
scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()
scale = (S / cur_param_var.clamp(min=eps)).sqrt()

if random.random() < 0.01:
skip = 10 if size < 20 else 1
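A quick numerical illustration of that rescaling (not from the diff): if the smoothed/accumulated variance S for a coordinate is 1.0 but the current parameter variance along that coordinate has collapsed to 0.01, the scale becomes sqrt(1.0 / 0.01) = 10, pulling that direction of p back up toward the accumulated statistics; a coordinate whose variance already matches S gets scale 1.

# Hedged toy example of scale = (S / cur_param_var.clamp(min=eps)).sqrt()
import torch

eps = 1.0e-08
S = torch.tensor([1.0, 1.0, 1.0])
cur_param_var = torch.tensor([0.01, 1.0, 4.0])
scale = (S / cur_param_var.clamp(min=eps)).sqrt()
print(scale)  # tensor([10.0000, 1.0000, 0.5000])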
@@ -684,18 +683,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# rms shape: (batch_size, 1, size, 1, 1)
rms = _mean(cur_p**2, exclude_dims=[0,dim], keepdim=True).sqrt()
rank = numel // size
# we did other kinds of smoothing in _get_smoothed_param_cov
#smoothed_rms = self._smooth_param_rms(group, rms, rank)
smoothed_rms = rms ** group["param_pow"]
cur_scales[dim] = smoothed_rms
cur_p /= smoothed_rms # normalize/"whiten" cur_p on this dim..
# TODO: consider more smoothing here???
cur_scales[dim] = rms
cur_p /= rms

if debug:
def _summarize(rms):
rms = rms[0] # get rid of batch dim by selecting one example
rms = rms.flatten()
return rms[::10] # subset one every ten items
logging.info(f"i={i} shape={tuple(p.shape)}, dim={dim}, rank={rank}, size={size}, rms={_summarize(rms)}, smoothed_rms={_summarize(smoothed_rms)}")
logging.info(f"i={i} shape={tuple(p.shape)}, dim={dim}, rank={rank}, size={size}, rms={_summarize(rms)}")

# Apply the scales in `cur_scales` to Q for each dim; this reflects the
# parameter rms values in the parameter-diagonalized space, that we have
@@ -731,92 +728,272 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
state["last_param_scale_update"] = state["step"]


def _get_smoothed_param_cov(self,
group: dict,
p: Tensor,
state: dict,
dim: int) -> Tensor:
def _compute_bases(self,
group: dict,
p_shape: torch.Size,
state: dict) -> List[Optional[Tensor]]:
"""
For each tensor dimension that we are preconditioning, this function sets
state[f"Q_{dim}"] to an orthogonal matrix (per batch element and
block); it will be indexed [batch_index, block_index, diagonalized_coordinate, canonical_coordinate].
This will be the matrix that diagonalizes Z,
where Z is the symmetric matrix satisfying:
Z G Z = P       (eqn:1)
with G being the gradient covariance grad_cov, P being the smoothed version of the parameter
covariance param_cov, and Z the symmetric positive definite learning-rate matrix.
We'll discuss later how we solve for Z. Firstly, because we want
a basis that is as meaningful as possible for smoothing P, we will first put P and G in a basis
where G is diagonal. That is, we do an SVD G = U G' U^T, and multiplying (eqn:1) on the
left and right by U^T and U respectively, and inserting some expressions equivalent to I,
we have:
(U^T Z U) (U^T G U) (U^T Z U) = (U^T P U)
or: Z' G' Z' = P'     (eqn:2)
where Z' = U^T Z U and P' = U^T P U, and of course G' is diagonal. We are actually going to
be solving (eqn:2), and then computing:
Z = U Z' U^T.

A solution to (eqn:1) is as follows. We are going to use a Cholesky-based solution in
preference to one that requires SVD or eigenvalue decomposition, because it is much faster (we first
have to be careful that the input is not close to singular, though).

So, with the Cholesky decomposition of P' being:
P' = C C^T,
the solution is going to be:
Z' = C (C^T G' C)^{-0.5} C^T       (eqn:3)
[[ we can verify that this is a solution by multiplying (eqn:2) by C^{-1} and C^{-T} on the
left and right respectively, giving:
(C^T G' C)^{-0.5} C^T G' C (C^T G' C)^{-0.5} = I
which we can immediately see is the case. ]]

Args:
group: dict to look up config values
p_shape: the shape of the batch of identical-sized tensors we are optimizing
state: dict to look up optimization state

Return:
This function returns a list of Tensors P_proj indexed by dim from 0..ndim-1, the tensors being
of shape (batch_size, size), which contain the diagonal of the smoothed parameter covariance
P after projection by the orthogonal matrix state[f"Q_{dim}"] that diagonalizes Z (this is
the space in which we do the actual update).
"""
ndim = len(p_shape)
P_proj = [None] * ndim
for dim in range(1, ndim):
size = p_shape[dim]
try:
Q = state[f"Q_{dim}"]
except KeyError:
assert size == 1 or size == numel, size
continue # e.g. size == 1 or size == numel


param_cov = state[f"param_cov_{dim}"] # (batch_size, num_blocks, block_size, block_size)
grad_cov = state[f"grad_cov_{dim}"] # (batch_size, num_blocks, block_size, block_size)
(batch_size, num_blocks, block_size, block_size) = param_cov.shape
# U_g diagonalizes grad_cov, in the sense that U_g^T grad_cov U_g is diagonal; its shape
# is the same as param_cov and grad_cov.
#
# G_prime is the diagonalized grad_cov, G' above, of shape (batch_size, num_blocks, block_size).
U_g, G_prime, _ = _svd(grad_cov)
# P_prime is P' above, which represents param_cov in the basis that diagonalizes G_prime.
# It is not smoothed yet.
P_prime = torch.matmul(U_g.transpose(2, 3), torch.matmul(param_cov, U_g))

P_prime = self._smooth_param_cov(group, p_shape, P_prime, G_prime)

C = P_prime.cholesky() # P_prime = torch.matmul(C, C.transpose(2, 3))

# CGC = (C^T G' C) which would be torch.matmul(C.transpose(2, 3), torch.matmul(G, C))
# if G were formatted as a diagonal matrix, but G is just the diagonal.
CGC = torch.matmul(C.transpose(2, 3) * G_prime.unsqueeze(-2),
C)
U, S, _ = _svd(CGC) # Need SVD to compute CGC^{-0.5}
# next we compute (eqn:3). The thing in the parentheses, CGC^{-0.5},
# can be written as U S^{-0.5} U^T, so the whole thing is
# (C U) S^{-0.5} (C U)^T
CU = torch.matmul(C, U)
S_inv_sqrt = 1.0 / S.sqrt()
Z_prime = torch.matmul(CU * S_inv_sqrt.unsqueeze(-2),
CU.transpose(2, 3))

if True:
# A check.
# Z_prime is supposed to satisfy: Z_prime G_prime Z_prime = P_prime (eqn:2)
P_prime_check = torch.matmul(Z_prime * G_prime.unsqueeze(-2), Z_prime)
diff_ratio = (P_prime - P_prime_check).abs().sum() / P_prime.abs().sum()
if diff_ratio > 0.01:
logging.warn(f"Z_prime does not satisfy its definition, diff_ratio = {diff_ratio}")
Z = torch.matmul(U_g, torch.matmul(Z_prime, U_g.transpose(2, 3)))
# OK, Z is the SPD transform that maps G to P, as in Z G Z = P.
# We just need the basis that diagonalizes this.
U_z, S, _ = _svd(Z)
if True:
skip = 10 if S.shape[-1] > 40 else 1
logging.info(f"Eigs of Z are: {S[0,0,::skip]}")

# state[f"Q_{dim}"] is indexed: [batch_idx, block_idx, diagonalized_coordinate, canonical_coordinate].
# so we need to transpose U_z as U_z is indexed
# [batch_idx, block_idx, canonical_coordinate, diagonalized_coordinate]
Q[:] = U_z.transpose(2, 3)

# Work out the projected smoothed parameter covariance P_proj, which is P
# projected in the basis U_z. Now,
# P = U_g P' U_g^T,
# and P_proj = U_z^T P U_z,
# so P_proj = (U_z^T U_g) P' (U_z^T U_g)^T
U_prod = torch.matmul(U_z.transpose(2, 3), U_g)
# this_P_proj shape: (batch_size, num_blocks, block_size)
this_P_proj = _diag(torch.matmul(U_prod, torch.matmul(P_prime, U_prod.transpose(2, 3))))
P_proj[dim] = this_P_proj.reshape(batch_size, size)
return P_proj

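A small standalone sketch of the Cholesky-based solve described in the docstring above (assumptions: plain PyTorch, a single unbatched block, no smoothing step). It builds random SPD matrices G and P, forms Z' = C (C^T G' C)^{-1/2} C^T in the basis that diagonalizes G, and checks that Z G Z is approximately P:

# Hedged sketch, not code from the diff: verify the Cholesky-based solution of Z G Z = P.
import torch

torch.manual_seed(0)
n = 6
A = torch.randn(n, n, dtype=torch.float64)
G = A @ A.T + 0.1 * torch.eye(n, dtype=torch.float64)   # SPD gradient covariance
B = torch.randn(n, n, dtype=torch.float64)
P = B @ B.T + 0.1 * torch.eye(n, dtype=torch.float64)   # SPD parameter covariance

# basis in which G is diagonal: G = U_g diag(G_prime) U_g^T
G_prime, U_g = torch.linalg.eigh(G)
P_prime = U_g.T @ P @ U_g

# Z' = C (C^T G' C)^{-1/2} C^T  with  P' = C C^T   (eqn:3 in the docstring)
C = torch.linalg.cholesky(P_prime)
CGC = C.T @ torch.diag(G_prime) @ C
S, U = torch.linalg.eigh(CGC)
CGC_inv_sqrt = U @ torch.diag(1.0 / S.sqrt()) @ U.T
Z_prime = C @ CGC_inv_sqrt @ C.T

Z = U_g @ Z_prime @ U_g.T
assert torch.allclose(Z @ G @ Z, P, atol=1e-8)   # Z G Z == P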
def _smooth_param_cov(self,
group: dict,
p_shape: torch.Size,
P_prime: Tensor,
G_prime: Tensor) -> Tensor:
"""
This function returns a modified/smoothed version of the parameter covariance
P_prime.
Args:
group: dict to look up config values
p_shape: The shape of the parameter we are optimizing
P_prime: a Tensor of shape (batch_size, num_blocks, block_size, block_size),
containing the parameter covariance in a basis that diagonalizes the
gradient covariance.
G_prime: the diagonalized gradient covariance, of shape (batch_size, num_blocks,
block_size)


state[f"param_cov_{dim}"], which is an estimate of the covariance of the parameter
p, averaged over time, and taken over dimension `dim` of the tensor.

The smoothing done here limits the extend to which the parameter covariance
The smoothing done here limits the extent to which the parameter covariance
can be strongly "off-diagonal" with respect to the gradient covariance. That is:
if the parameter covariance is just the gradient covariance to some power, this
function does no smoothing; but if it is highly off-diagonal we do more smoothing.
"""
param_cov = state[f"param_cov_{dim}"] # (batch_size, num_blocks, block_size, block_size)
grad_cov = state[f"grad_cov_{dim}"] # (batch_size, num_blocks, block_size, block_size)
(batch_size, num_blocks, block_size, block_size) = param_cov.shape
U_g, _, _ = _svd(grad_cov) # U_g diagonalizes grad_cov, in the sense that U_g^T grad_cov U_g is diagonal.
P_prime_diag = _diag(P_prime) # (batch_size, num_blocks, block_size)
eps = 1.0e-10
P_prime_diag = (P_prime_diag + eps) / P_prime_diag.mean()
# make sure no diagonal element is close to zero.. we don't expect this
# to happen. this is likely not important. Note, this is just used for
# normalizing P prior to smoothing.
P_prime_diag.clamp_(min=0.01)
P_prime_rms = P_prime_diag.sqrt()
P_prime_scale = P_prime_rms.unsqueeze(-1) * P_prime_rms.unsqueeze(-2)

# param_cov_proj is param_cov in a different orthonormal basis, that diagonalizes
# grad_cov.
param_cov_proj = torch.matmul(U_g.transpose(2, 3), torch.matmul(param_cov, U_g))
# P_norm will have diagonal elements close to 1. We do some smoothing
# in this space.
P_norm = P_prime / P_prime_scale
# Now P is as normalized as we can make it... do smoothing based on 'rank',
# that is intended to compensate for bad estimates of P.
batch_size = p_shape[0]
size = P_prime.shape[0] # size of dim we are concerned with right now
# `rank` is the rank of P_prime if we were to estimate it from just one
# parameter tensor. We average it over time, but actually it won't be changing
# too much, so `rank` does tell us something.
rank = p_shape.numel() // (size * batch_size)
smooth0 = group["param_rms_smooth0"]
smooth1 = group["param_rms_smooth1"]
# We want expr for smoothing amount to be of the form: smooth = alpha * size / (beta*rank + size)
# param_rms_smooth{0,1} represents the user-specified desired amount of smoothing
# when rank==0*size and rank==1*size, respectively.
# from rank==0*size, we get smooth0 = alpha * size/size, so alpha = smooth0.
# from setting rank==size, we get smooth1 = alpha * size / (beta*size + size) = alpha/(1+beta),
# so smooth1 == smooth0 / (1+beta), so (1+beta) = smooth0/smooth1, so beta=smooth0/smooth1 - 1
smooth = smooth0 * size / ((smooth0/smooth1 - 1) * rank + size)

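A quick sanity check of this interpolation formula (not from the diff), using the constructor defaults smooth0=0.75 and smooth1=0.25: beta = 0.75/0.25 - 1 = 2, so rank==0 gives smooth = 0.75 and rank==size gives smooth = 0.75*size/(2*size + size) = 0.25, matching the two anchor points described in the comments.

# Hedged check of the rank-dependent smoothing amount, with the defaults
# param_rms_smooth0=0.75 and param_rms_smooth1=0.25 from the constructor above.
def smooth_amount(size, rank, smooth0=0.75, smooth1=0.25):
    return smooth0 * size / ((smooth0 / smooth1 - 1) * rank + size)

assert abs(smooth_amount(512, 0) - 0.75) < 1e-9     # rank == 0     -> smooth0
assert abs(smooth_amount(512, 512) - 0.25) < 1e-9   # rank == size  -> smooth1
print(smooth_amount(512, 2048))                     # more evidence -> ~0.083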
# param_cov_eps is probably not critical, I don't expect to see super
# small values. apply as floor in case roundoff causes negative values.
param_cov_eps = 1.0e-05
param_rms = _diag(param_cov_proj).clamp_(min=param_cov_eps).sqrt()
param_cov_inv_scale = param_rms.unsqueeze(-1) * param_rms.unsqueeze(-2)
# add rank-dependent smoothing amount to diagonal of P_prime. _diag() returns an aliased tensor.
# we don't need to multiply `smooth` by anything, because at this point, P_prime should have
# diagonal elements close to 1.
_diag(P_prime).add_(smooth)

# param_cov_norm should have diagonal values close to 1.0 (only not
# exactly 1.0 due to param_cov_eps and roundoff)
param_cov_norm = param_cov_proj / param_cov_inv_scale
P_norm = self._smooth_cov(P_norm,
group["min_lr_factor"][0],
group["max_lr_factor"][0])
# Remove the diagonal preconditioning on P_norm, giving us stage-1-smoothed
# version of P_prime.
P_prime = P_norm * P_prime_scale

# OK, this is where we do smoothing.
# decompose param_cov_norm, which is symmetric, as U_p S U_p^T
U_p, S, _ = _svd(param_cov_norm)
# Make sure G_prime has unit mean and no eigenvalue is super small. Note, G_prime
# is already diagonal.
G_prime_mean = _mean(G_prime, exclude_dims=[0], keepdim=True)
G_prime_smooth = 0.001
# make sure G_prime has no zero eigs, and is unit mean.
G_prime = ((G_prime + eps + G_prime_smooth * G_prime_mean) /
(G_prime_mean * (1+G_prime_smooth) + eps))
G_prime_rms = G_prime.sqrt()
G_prime_scale = G_prime_rms.unsqueeze(-1) * G_prime_rms.unsqueeze(-2)
# P_gnorm is a version of P_prime that is scaled relative to G, i.e.
# scaled in such a way that would make G the unit matrix.
P_gnorm = P_prime / G_prime_scale
# Apply another round of smoothing "relative to G"
P_gnorm = self._smooth_cov(P_gnorm,
group["min_lr_factor"][1],
group["max_lr_factor"][1])
# Undo the scaling relative to G, so we have stage-2-smoothed version of P_prime.
P_prime = P_gnorm * G_prime_scale

# Apply a 3rd round of smoothing
P_prime = self._smooth_cov(P_prime,
group["min_lr_factor"][2],
group["max_lr_factor"][2])
return P_prime

def _smooth_cov(self,
X: Tensor,
min_eig: float,
max_eig: float,
power: float = 1.0) -> Tensor:
"""
Returns a `smoothed` version of a symmetric positive definite covariance matrix
[with block-diagonal structure, in a batch]. This is done without SVD (which
can be very slow).
The eigenvalues L will be transformed as:

residual_rms = S.sqrt()
#
relative_rms_pow = 0.7
relative_rms_max = 4.0
L = L + min_eig * L.mean() + eps
L /= L.mean()
L = 1 / (1/L + 1/max_eig) # soft-min between L and max_eig
L = L ** power # need SVD for this, will get rid of the requirement later.
L /= L.mean()

residual_rms = residual_rms ** relative_rms_pow
residual_rms /= _mean(residual_rms, exclude_dims=[0], keepdim=True)
# Note on approximation functions like x^0.75 for smallish x: on wolframalpha, type:
# plot x^0.75 and 0.05 + (1.1x - 0.18 x^2 + 0.02 x^3) for x from 0 to 10
# [this starts to diverge after 5 or so]

if True:
# smooth according to the rank of the observation..
size = p.shape[dim]
rank = p.numel() // (size * batch_size)
smooth0 = group["param_rms_smooth0"]
smooth1 = group["param_rms_smooth1"]
# want expr to be of the form: smooth = alpha * size / (beta*rank + size)
# from rank==0, we get smooth0 = alpha * size/size, so alpha = smooth0.
# from setting rank==size, we get smooth1 = alpha * size / (beta*size * size) = alpha/(1+beta),
# so smooth1 == smooth0 / (1+beta), so (1+beta) = smooth0/smooth1, so beta=smooth0/smooth1 - 1
smooth = smooth0 * size / ((smooth0/smooth1 - 1) * rank + size)

mean = _mean(residual_rms, exclude_dims=[0], keepdim=True)
residual_rms += group["eps"] + smooth * mean
residual_rms = residual_rms / _mean(residual_rms, exclude_dims=[0], keepdim=True)


# apply the maximum via a softmin function, softmin(x,y) = 1/(1/x + 1/y)
residual_rms = 1. / (1. / residual_rms + 1. / relative_rms_max)

if random.random() < 0.1:
skip = 10 if S.shape[-1] > 40 else 1
logging.info(f"Smoothed param_rms from {S.sqrt()[0,0,::skip]} to {residual_rms[0,0,::skip]}, param_rms={param_rms[0,0,::skip]}")

# U shape: (batch_size, num_blocks, block_size, block_size),
# interpreted as
# residual_rms shape: (batch_size, num_blocks, block_size).
# so in terms of matrix multiplication, we are computing X_p = matmul(U_p, residual_rms.diag())
X_p = U_p * residual_rms.unsqueeze(-2)
param_cov_norm_smoothed = torch.matmul(X_p, X_p.transpose(2, 3))

# Undo the scaling by the diagonal of param_cov
param_cov_proj_smoothed = param_cov_norm_smoothed * param_cov_inv_scale

# Undo the projection by U.
param_cov_smoothed = torch.matmul(U_g, torch.matmul(param_cov_proj_smoothed,
U_g.transpose(2, 3)))
return param_cov_smoothed
Args:
min_eig: minimum allowed eigenvalue of returned X
max_eig: maximum allowed eigenvalue of returned X
power: power to take eigenvalues to
X: the batch of symmetric positive definite tensors we are smoothing;
of shape (batch_size, num_blocks, block_size, block_size)
"""
if power != 1.0:
U, S, _ = _svd(X)
eps = 1.0e-10
S_mean = _mean(S, exclude_dims=[0], keepdim=True)
S = S + min_eig * S_mean + eps
S_mean = S_mean * (1 + min_eig) + eps
S = S / S_mean
S = 1. / (1./S + 1./max_eig)
S = S ** power
S = S / _mean(S, exclude_dims=[0], keepdim=True)
return torch.matmul(U * S.unsqueeze(-2), U.transpose(2, 3))
else:
diag = _diag(X) # Aliased with X
mean_eig = _mean(diag, exclude_dims=[0], keepdim=True)
eps = 1.0e-10 # prevent division by zero
diag += (mean_eig * min_eig + eps)
cur_diag_mean = mean_eig * (1 + min_eig) + eps
# The following 2 statements will be equivalent to:
# L /= L.mean()
# L = 1 / (1/L + 1/max_eig) # soft-min between L and max_eig
# if L is the eigenvalues
X = (X.inverse() + 1/(max_eig * cur_diag_mean.unsqueeze(-1))).inverse()
X /= _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1)
return X

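The comment in the else-branch relies on the identity that, for a symmetric positive definite X with eigenvalues L, (X^{-1} + c*I)^{-1} has eigenvalues 1/(1/L + c), i.e. a soft-min of the eigenvalues against 1/c without any SVD. A minimal sketch of that identity (not code from the diff; it adds the multiple of the identity explicitly and skips the normalization by the diagonal mean):

# Hedged sketch: soft-min of eigenvalues via matrix inverses, no SVD needed.
import torch

torch.manual_seed(0)
n, max_eig = 5, 10.0
A = torch.randn(n, n, dtype=torch.float64)
X = A @ A.T + 0.01 * torch.eye(n, dtype=torch.float64)   # SPD test matrix

# eigenvalue-space version: L -> 1 / (1/L + 1/max_eig)
L, U = torch.linalg.eigh(X)
ref = U @ torch.diag(1.0 / (1.0 / L + 1.0 / max_eig)) @ U.T

# inverse-space version: (X^{-1} + (1/max_eig) I)^{-1}
out = torch.linalg.inv(torch.linalg.inv(X) + torch.eye(n, dtype=torch.float64) / max_eig)

assert torch.allclose(out, ref, atol=1e-8)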
def _diagonalize_grad_cov(self,
@@ -1144,22 +1321,27 @@ def _move_dim(x: Tensor, orig_dim: int, new_dim: int) -> Tensor:

def _diag(x: Tensor):
"""
like torch diag(), but supports batch dim, i.e. input of shape (B, M, M) returns
like torch diag(), this returns the diagonal of a matrix; but it returns an
aliased tensor, and it supports batch dims,
i.e. input of shape (B, M, M) returns
output of shape (B, M), or input of shape (A, B, M, M) returns output of shape
(A, B, M)
(A, B, M).
"""
stride = x.stride()
if x.ndim == 3:
(B, M, M2) = x.shape
assert M == M2
stride = x.stride()
return x.as_strided(size=(B, M), stride=(stride[0], stride[1] + stride[2])).contiguous()
ans = x.as_strided(size=(B, M), stride=(stride[0], stride[1] + stride[2]))
elif x.ndim == 4:
(B, C, M, M2) = x.shape
assert M == M2
stride = x.stride()
return x.as_strided(size=(B, C, M), stride=(stride[0], stride[1], stride[2] + stride[3])).contiguous()
else:
return x.diag()
ans = x.as_strided(size=(B, C, M), stride=(stride[0], stride[1], stride[2] + stride[3])).contiguous()
elif x.ndim == 2:
(M, M2) = x.shape
assert M == M2
ans = x.as_strided(size=(M,), stride=(stride[0] + stride[1],)).contiguous()
return ans
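Since the new (B, M, M) branch drops .contiguous(), the returned diagonal is a strided view that aliases the input, which is what lets calls like _diag(P_prime).add_(smooth) earlier in this diff modify the matrix in place. A small sketch of that behaviour, using a standalone copy of that branch for illustration rather than an import from optim.py:

# Hedged sketch of the aliasing behaviour of the (B, M, M) branch above.
import torch

def diag_view(x: torch.Tensor) -> torch.Tensor:
    (B, M, M2) = x.shape
    assert M == M2
    stride = x.stride()
    return x.as_strided(size=(B, M), stride=(stride[0], stride[1] + stride[2]))

X = torch.zeros(2, 3, 3)
diag_view(X).add_(0.25)   # in-place update through the view...
print(X[0])               # ...lands on the diagonal of X itself
# tensor([[0.2500, 0.0000, 0.0000],
#         [0.0000, 0.2500, 0.0000],
#         [0.0000, 0.0000, 0.2500]])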

def _sum(x: Tensor,
exclude_dims: List[int] = [],