First version after refactoring and changing the math, where optim.py runs

This commit is contained in:
Daniel Povey 2022-07-23 06:32:56 +08:00
parent 4da4e69fba
commit dd10eb140f


@@ -142,10 +142,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 lr=3e-02,
 betas=(0.9, 0.98),
 size_lr_scale=0.1,
-param_pow=1.0,
+min_lr_factor=(0.05, 0.05, 0.05),
+max_lr_factor=(10.0, 10.0, 10.0),
 param_rms_smooth0=0.75,
 param_rms_smooth1=0.25,
-max_lr_factor=10.0,
 eps=1.0e-08,
 param_min_rms=1.0e-05,
 param_max_rms=2.0,
@@ -153,7 +153,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 size_update_period=4,
 lr_update_period=(200, 1000),
 grad_cov_period=3,
-param_cov_period=100,
 max_block_size=1024,
 ):
@@ -162,10 +161,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 defaults = dict(
 lr=lr,
 size_lr_scale=size_lr_scale,
-param_pow=param_pow,
+min_lr_factor=min_lr_factor,
+max_lr_factor=max_lr_factor,
 param_rms_smooth0=param_rms_smooth0,
 param_rms_smooth1=param_rms_smooth1,
-max_lr_factor=max_lr_factor,
 betas=betas,
 eps=eps,
 param_min_rms=param_min_rms,
@@ -174,7 +173,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 size_update_period=size_update_period,
 lr_update_period=lr_update_period,
 grad_cov_period=grad_cov_period,
-param_cov_period=param_cov_period,
 max_block_size=max_block_size,
 )
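The commit replaces the scalar param_pow / max_lr_factor options with tuple-valued min_lr_factor / max_lr_factor, one (min, max) relative-eigenvalue limit per smoothing stage in _smooth_param_cov below. A minimal usage sketch; the optimizer class name PrAdam and the model variable are placeholders, not taken from this diff:

    # Hypothetical usage; only the argument names and default values come from the commit.
    optimizer = PrAdam(
        model.parameters(),
        lr=3e-02,
        betas=(0.9, 0.98),
        min_lr_factor=(0.05, 0.05, 0.05),   # per-stage minimum relative eigenvalue
        max_lr_factor=(10.0, 10.0, 10.0),   # per-stage maximum relative eigenvalue
        param_rms_smooth0=0.75,
        param_rms_smooth1=0.25,
    )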
@@ -283,7 +281,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 return
 # "zero_step" being a member of state is the sign that this parameter has
-# at least one dim that has a projection.
+# at least one dim that has a projection.  It also records the most recent
+# step on which we zeroed state["exp_avg_sq"].
 state["zero_step"] = 0
 # last_param_scale_update records the last time we updated the part of the learning rate
 # matrices that relates to the parameter covariance; we avoid doing this too often
@@ -359,7 +358,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 lr = group["lr"]
 size_update_period = group["size_update_period"]
 grad_cov_period = group["grad_cov_period"]
-param_cov_period = group["param_cov_period"]
 eps = group["eps"]
 beta1 = group["betas"][0]
@@ -393,12 +391,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 else:
 if self._is_lr_update_step(group, state):
 self._update_param_cov(group, p, state)
+if step > state["last_param_scale_update"] * 1.1 and state["last_param_scale_update"] != state["zero_step"]:
+P_proj = self._compute_bases(group, p.shape, state)
 # Only update the parameter-dependent part of the learning
 # rate matrices at most every other time we reach here, and
 # less frequently than that later in training.
-self._update_param_scales(group, p, state)
-self._diagonalize_grad_cov(group, p, state)
+self._update_param_scales(group, p, state, P_proj)
+# We won't be doing this any more.
+#self._diagonalize_grad_cov(group, p, state)
 self._zero_exp_avg_sq(state)
 if step % grad_cov_period == 0:
 self._update_grad_cov(group, p, state)
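The step > last_param_scale_update * 1.1 guard spaces the expensive basis/scale updates roughly geometrically through training. A standalone sketch of the schedule that this factor alone produces (ignoring the _is_lr_update_step and zero_step conditions; the numbers are illustrative only):

    # Illustrative only: steps at which the parameter-scale update would be allowed,
    # assuming it last ran at step 200 and using only the 1.1 factor in the guard above.
    last = 200
    allowed = []
    for step in range(201, 600):
        if step > last * 1.1:
            allowed.append(step)
            last = step
    print(allowed)  # [221, 244, 269, 296, 326, 359, 395, 435, 479, 527, 580]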
@@ -474,7 +475,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 (except batch and trivial and rank-1 dims)
 """
 eps = group["eps"]
-param_cov_period = group["param_cov_period"]
 # zero_step is always the last time we called _update_param_cov.
 # Our aim is to compute the parameter covariance averaged over all time
@@ -555,7 +555,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 def _update_param_scales(self,
 group: dict,
 p: Tensor,
-state: dict) -> None:
+state: dict,
+P_proj: List[Optional[Tensor]]) -> None:
 """
 Computes learning-rate matrices Q for each dim of this tensor: only the part that depends
 on the parameter covariance, we will later add a rotation that depends on the gradient
@@ -567,6 +568,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 are actually factors of p, so p itself will change when we change
 them.
 state: state dict for the current parameter
+P_proj: a list indexed by dim, containing tensors of shape (batch_size, size)
+where size == p.shape[dim], which represent the parameter covariance
+projected by state[f"Q_{dim}"] (i.e. by the value of this variable
+at entry to this function)
 """
 ndim = p.ndim
 batch_size = p.shape[0]
@@ -597,15 +602,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 assert size == 1 or size == numel, size
 continue  # e.g. size == 1 or size == numel
-param_cov = self._get_smoothed_param_cov(group, p, state, dim)
-# param_cov has the same shape as Q
 (batch_size, num_blocks, block_size, block_size) = Q.shape
-# U,S,V have extra leading dimensions, i.e. of shape
-# {U,V}.shape == (batch_size, num_blocks, block_size, block_size),
-# S.shape == (batch_size, num_blocks, block_size).
-U, S, V = _svd(param_cov)
-Q[:] = U.transpose(2, 3)
 M = cur_p.transpose(dim, -1)
 # if p were of shape (batch_size, x, size, y, z),
@@ -615,11 +612,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 num_blocks, block_size)
 M = _move_dim(M, -2, 1)  # (batch_size, num_blocks, x, y, z, block_size)
-while U.ndim < M.ndim:
-U = U.unsqueeze(2)
-# Now U is of shape (batch_size, num_blocks, 1, 1, block_size, block_size)
-M = torch.matmul(M, U)  # (batch_size, num_blocks, x, y, z, block_size)
+while Q.ndim < M.ndim:
+Q = Q.unsqueeze(2)
+# Now Q is of shape (batch_size, num_blocks, 1, 1, block_size, block_size)
+# [batch_index, block_index, diagonalized_coordinate, canonical_coordinate],
+# so we need to transpose Q as we convert M to the diagonalized co-ordinate.
+M = torch.matmul(M, Q.transpose(2, 3))  # (batch_size, num_blocks, x, y, z, block_size)
 M = _move_dim(M, 1, -2)  # (batch_size, x, y, z, num_blocks, block_size)
 M = M.reshape(*M.shape[:-2], size)  # (batch_size, x, y, z, size)
 cur_p = M.transpose(dim, -1)  # (batch_size, x, size, y, z)
@@ -630,7 +628,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 cur_param_var = _mean(cur_p**2,
 exclude_dims=[0,dim],
 keepdim=True)  # (batch_size, 1, size, 1, 1) if dim==2
-S = S.reshape(cur_param_var.shape)  # (batch_size, 1, size, 1, 1) if dim==2
+smoothed_param_var = P_proj[dim]  # (batch_size, size)
+S = smoothed_param_var.reshape(cur_param_var.shape)  # (batch_size, 1, size, 1, 1) if dim==2
 # OK, cur_param_var would have the same values as S if the variance stats
 # param_cov_{dim} were accumulated from this exact parameter matrix,
@@ -639,7 +638,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 # spectrum").  We scale p so that it matches the accumulated stats,
 # the idea is to ensure it doesn't have any too-small eigenvalues
 # (where the stats permit).
-scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()
+scale = (S / cur_param_var.clamp(min=eps)).sqrt()
 if random.random() < 0.01:
 skip = 10 if size < 20 else 1
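A small self-contained check of the indexing convention used here, namely that Q is stored as [batch, block, diagonalized_coordinate, canonical_coordinate] (i.e. Q = U^T with eigenvectors as the columns of U), so conjugating by Q takes the covariance to its diagonal form. This is a standalone sketch with random matrices, not code from the commit:

    import torch

    batch, blocks, bs = 2, 3, 4
    A = torch.randn(batch, blocks, bs, bs)
    P = torch.matmul(A, A.transpose(2, 3)) + 0.1 * torch.eye(bs)   # random SPD per (batch, block)
    S, U = torch.linalg.eigh(P)        # columns of U are eigenvectors in canonical coordinates
    Q = U.transpose(2, 3)              # stored as [..., diagonalized_coordinate, canonical_coordinate]
    # Conjugating with Q maps canonical coordinates to the diagonalized basis:
    P_diag = torch.matmul(Q, torch.matmul(P, Q.transpose(2, 3)))
    off_diag = P_diag - torch.diag_embed(torch.diagonal(P_diag, dim1=2, dim2=3))
    print(off_diag.abs().max())        # ~0: P is diagonal in the Q basis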
@@ -684,18 +683,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 # rms shape: (batch_size, 1, size, 1, 1)
 rms = _mean(cur_p**2, exclude_dims=[0,dim], keepdim=True).sqrt()
 rank = numel // size
-# we did other kinds of smoothing in _get_smoothed_param_cov
-#smoothed_rms = self._smooth_param_rms(group, rms, rank)
-smoothed_rms = rms ** group["param_pow"]
-cur_scales[dim] = smoothed_rms
-cur_p /= smoothed_rms  # normalize/"whiten" cur_p on this dim..
+# TODO: consider more smoothing here???
+cur_scales[dim] = rms
+cur_p /= rms
 if debug:
 def _summarize(rms):
 rms = rms[0]  # get rid of batch dim by selecting one example
 rms = rms.flatten()
 return rms[::10]  # subset one every ten items
-logging.info(f"i={i} shape={tuple(p.shape)}, dim={dim}, rank={rank}, size={size}, rms={_summarize(rms)}, smoothed_rms={_summarize(smoothed_rms)}")
+logging.info(f"i={i} shape={tuple(p.shape)}, dim={dim}, rank={rank}, size={size}, rms={_summarize(rms)}")
 # Apply the scales in `cur_scales` to Q for each dim; this reflects the
 # parameter rms values in the parameter-diagonalized space, that we have
@@ -731,92 +728,272 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 state["last_param_scale_update"] = state["step"]
-def _get_smoothed_param_cov(self,
+def _compute_bases(self,
 group: dict,
-p: Tensor,
-state: dict,
-dim: int) -> Tensor:
+p_shape: torch.Size,
+state: dict) -> List[Optional[Tensor]]:
+"""
+For each tensor dimension that we are preconditioning, this function sets
+state[f"Q_{dim}"] to an orthogonal matrix (per batch element and
+block); it will be indexed [batch_index, block_index, diagonalized_coordinate, canonical_coordinate].
+This will be the matrix that diagonalizes Z,
+where Z is the symmetric matrix satisfying:
+     Z G Z = P          (eqn:1)
+with G being the gradient covariance grad_cov, P being the smoothed version of the parameter
+covariance param_cov, and Z the symmetric positive definite learning-rate matrix.
+We'll discuss later how we solve for Z.  Firstly, because we want
+a basis that is as meaningful as possible for smoothing P, we will first put P and G in a basis
+where G is diagonal.  That is, we do an SVD G = U G' U^T, and multiplying (eqn:1) on the
+left and right by U^T and U respectively, and inserting some expressions equivalent to I,
+we have:
+     (U^T Z U) (U^T G U) (U^T Z U) = (U^T P U)
+or:  Z' G' Z' = P'      (eqn:2)
+where Z' = U^T Z U and P' = U^T P U, and of course G' is diagonal.  We are actually going to
+be solving (eqn:2), and then computing:
+     Z = U Z' U^T.
+A solution to (eqn:2) is as follows.  We are going to be using a Cholesky-based solution in
+preference to one that requires SVD or eigenvalue decomposition, because it is much faster (we first
+have to be careful that the input is not close to singular, though).
+So, with the Cholesky decomposition of P' being:
+     P' = C C^T,
+the solution is going to be:
+     Z' = C (C^T G' C)^{-0.5} C^T     (eqn:3)
+[[ we can verify that this is a solution by multiplying (eqn:2) by C^{-1} and C^{-T} on the
+left and right respectively, giving:
+     (C^T G' C)^{-0.5} C^T G' C (C^T G' C)^{-0.5} = I
+which we can immediately see is the case. ]]
+Args:
+  group: dict to look up config values
+  p_shape: the shape of the batch of identical-sized tensors we are optimizing
+  state: dict to look up optimization state
+Return:
+  This function returns a list of Tensors P_proj indexed by dim from 0..ndim-1, the tensors being
+  of shape (batch_size, size), which contain the diagonal of the smoothed parameter covariance
+  P after projection by the orthogonal matrix state[f"Q_{dim}"] that diagonalizes Z (this is
+  the space in which we do the actual update).
+"""
+ndim = len(p_shape)
+P_proj = [None] * ndim
+for dim in range(1, ndim):
+size = p_shape[dim]
+try:
+Q = state[f"Q_{dim}"]
+except KeyError:
+assert size == 1 or size == numel, size
+continue  # e.g. size == 1 or size == numel
+param_cov = state[f"param_cov_{dim}"]  # (batch_size, num_blocks, block_size, block_size)
+grad_cov = state[f"grad_cov_{dim}"]  # (batch_size, num_blocks, block_size, block_size)
+(batch_size, num_blocks, block_size, block_size) = param_cov.shape
+# U_g diagonalizes grad_cov, in the sense that U_g^T grad_cov U_g is diagonal; its shape
+# is the same as param_cov and grad_cov.
+#
+# G_prime is the diagonalized grad_cov, G' above, of shape (batch_size, num_blocks, block_size).
+U_g, G_prime, _ = _svd(grad_cov)
+# P_prime is P' above, which represents param_cov in the basis that diagonalizes G_prime.
+# It is not smoothed yet.
+P_prime = torch.matmul(U_g.transpose(2, 3), torch.matmul(param_cov, U_g))
+P_prime = self._smooth_param_cov(group, p_shape, P_prime, G_prime)
+C = P_prime.cholesky()  # P_prime = torch.matmul(C, C.transpose(2, 3))
+# CGC = (C^T G' C), which would be torch.matmul(C.transpose(2, 3), torch.matmul(G, C))
+# if G were formatted as a diagonal matrix, but G is just the diagonal.
+CGC = torch.matmul(C.transpose(2, 3) * G_prime.unsqueeze(-2),
+C)
+U, S, _ = _svd(CGC)  # Need SVD to compute CGC^{-0.5}
+# next we compute (eqn:3).  The thing in the parentheses, CGC^{-0.5},
+# can be written as U S^{-0.5} U^T, so the whole thing is
+# (C U) S^{-0.5} (C U)^T
+CU = torch.matmul(C, U)
+S_inv_sqrt = 1.0 / S.sqrt()
+Z_prime = torch.matmul(CU * S_inv_sqrt.unsqueeze(-2),
+CU.transpose(2, 3))
+if True:
+# A check.
+# Z_prime is supposed to satisfy: Z_prime G_prime Z_prime = P_prime (eqn:2)
+P_prime_check = torch.matmul(Z_prime * G_prime.unsqueeze(-2), Z_prime)
+diff_ratio = (P_prime - P_prime_check).abs().sum() / P_prime.abs().sum()
+if diff_ratio > 0.01:
+logging.warn(f"Z_prime does not satisfy its definition, diff_ratio = {diff_ratio}")
+Z = torch.matmul(U_g, torch.matmul(Z_prime, U_g.transpose(2, 3)))
+# OK, Z is the SPD transform that maps G to P, as in Z G Z = P.
+# We just need the basis that diagonalizes this.
+U_z, S, _ = _svd(Z)
+if True:
+skip = 10 if S.shape[-1] > 40 else 1
+logging.info(f"Eigs of Z are: {S[0,0,::skip]}")
+# state[f"Q_{dim}"] is indexed: [batch_idx, block_idx, diagonalized_coordinate, canonical_coordinate],
+# so we need to transpose U_z, as U_z is indexed
+# [batch_idx, block_idx, canonical_coordinate, diagonalized_coordinate].
+Q[:] = U_z.transpose(2, 3)
+# Work out the projected smoothed parameter covariance P_proj, which is P
+# projected in the basis U_z.  Now,
+#    P = U_g P' U_g^T,
+# and P_proj = U_z^T P U_z,
+# so P_proj = (U_z^T U_g) P' (U_z^T U_g)^T
+U_prod = torch.matmul(U_z.transpose(2, 3), U_g)
+# this_P_proj shape: (batch_size, num_blocks, block_size)
+this_P_proj = _diag(torch.matmul(U_prod, torch.matmul(P_prime, U_prod.transpose(2, 3))))
+P_proj[dim] = this_P_proj.reshape(batch_size, size)
+return P_proj
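The Cholesky-based solve of Z' G' Z' = P' described in the docstring can be checked numerically in isolation. A standalone sketch (single dense matrices, torch.linalg in place of the file's _svd helper; not code from the commit):

    import torch

    torch.manual_seed(0)
    n = 6
    def random_spd(n):
        A = torch.randn(n, n, dtype=torch.float64)
        return A @ A.T + 0.1 * torch.eye(n, dtype=torch.float64)

    P, G = random_spd(n), random_spd(n)

    # Basis where G is diagonal: G = U diag(G_prime) U^T
    G_prime, U = torch.linalg.eigh(G)
    P_prime = U.T @ P @ U

    # Z' = C (C^T G' C)^{-1/2} C^T, with P' = C C^T   (eqn:3)
    C = torch.linalg.cholesky(P_prime)
    CGC = C.T @ torch.diag(G_prime) @ C
    S, V = torch.linalg.eigh(CGC)
    CGC_inv_sqrt = V @ torch.diag(S.rsqrt()) @ V.T
    Z_prime = C @ CGC_inv_sqrt @ C.T
    Z = U @ Z_prime @ U.T

    # Check Z G Z = P   (eqn:1)
    print(torch.allclose(Z @ G @ Z, P, atol=1e-6))   # True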
+def _smooth_param_cov(self,
+group: dict,
+p_shape: torch.Size,
+P_prime: Tensor,
+G_prime: Tensor) -> Tensor:
 """
 This function returns a modified/smoothed version of the parameter covariance
+P_prime.
+Args:
+  group: dict to look up config values
+  p_shape: The shape of the parameter we are optimizing
+  P_prime: a Tensor of shape (batch_size, num_blocks, block_size, block_size),
+    containing the parameter covariance in a basis that diagonalizes the
+    gradient covariance.
+  G_prime: the diagonalized gradient covariance, of shape (batch_size, num_blocks,
+    block_size)
 state[f"param_cov_{dim}"], which is an estimate of the covariance of the parameter
 p, averaged over time, and taken over dimension `dim` of the tensor.
-The smoothing done here limits the extend to which the parameter covariance
+The smoothing done here limits the extent to which the parameter covariance
 can be strongly "off-diagonal" with respect to the gradient covariance.  That is:
 if the parameter covariance is just the gradient covariance to some power, this
 function does no smoothing; but if it is highly off-diagonal we do more smoothing.
 """
-param_cov = state[f"param_cov_{dim}"]  # (batch_size, num_blocks, block_size, block_size)
-grad_cov = state[f"grad_cov_{dim}"]  # (batch_size, num_blocks, block_size, block_size)
-(batch_size, num_blocks, block_size, block_size) = param_cov.shape
-U_g, _, _ = _svd(grad_cov)  # U_g diagonalizes grad_cov, in the sense that U_g^T grad_cov U_g is diagonal.
-# param_cov_proj is param_cov in a different orthonormal basis, that diagonalizes
-# grad_cov.
-param_cov_proj = torch.matmul(U_g.transpose(2, 3), torch.matmul(param_cov, U_g))
-# param_cov_eps is probably not critical, I don't expect to see super
-# small values.  apply as floor in case roundoff causes negative values.
-param_cov_eps = 1.0e-05
-param_rms = _diag(param_cov_proj).clamp_(min=param_cov_eps).sqrt()
-param_cov_inv_scale = param_rms.unsqueeze(-1) * param_rms.unsqueeze(-2)
-# param_cov_norm should have diagonal values close to 1.0 (only not
-# exactly 1.0 due to param_cov_eps and roundoff)
-param_cov_norm = param_cov_proj / param_cov_inv_scale
-# OK, this is where we do smoothing.
-# decompose param_cov_norm, which is symmetric, as U_p S U_p^T
-U_p, S, _ = _svd(param_cov_norm)
-residual_rms = S.sqrt()
-#
-relative_rms_pow = 0.7
-relative_rms_max = 4.0
-residual_rms = residual_rms ** relative_rms_pow
-residual_rms /= _mean(residual_rms, exclude_dims=[0], keepdim=True)
-if True:
-# smooth according to the rank of the observation..
-size = p.shape[dim]
-rank = p.numel() // (size * batch_size)
+P_prime_diag = _diag(P_prime)  # (batch_size, num_blocks, block_size)
+eps = 1.0e-10
+P_prime_diag = (P_prime_diag + eps) / P_prime_diag.mean()
+# make sure no diagonal element is close to zero.. we don't expect this
+# would happen.  this is likely not important.  Note, this is just used for
+# normalizing P prior to smoothing.
+P_prime_diag.clamp_(min=0.01)
+P_prime_rms = P_prime_diag.sqrt()
+P_prime_scale = P_prime_rms.unsqueeze(-1) * P_prime_rms.unsqueeze(-2)
+# P_norm will have diagonal elements close to 1.  We do some smoothing
+# in this space.
+P_norm = P_prime / P_prime_scale
+# Now P is as normalized as we can make it... do smoothing based on 'rank',
+# that is intended to compensate for bad estimates of P.
+batch_size = p_shape[0]
+size = P_prime.shape[0]  # size of dim we are concerned with right now
+# `rank` is the rank of P_prime if we were to estimate it from just one
+# parameter tensor.  We average it over time, but actually it won't be changing
+# too much, so `rank` does tell us something.
+rank = p_shape.numel() // (size * batch_size)
 smooth0 = group["param_rms_smooth0"]
 smooth1 = group["param_rms_smooth1"]
-# want expr to be of the form: smooth = alpha * size / (beta*rank + size)
-# from rank==0, we get smooth0 = alpha * size/size, so alpha = smooth0.
+# We want expr for smoothing amount to be of the form: smooth = alpha * size / (beta*rank + size)
+# param_rms_smooth{0,1} represent the user-specified desired amount of smoothing
+# when rank==0*size and rank==1*size, respectively.
+# from rank==0*size, we get smooth0 = alpha * size/size, so alpha = smooth0.
 # from setting rank==size, we get smooth1 = alpha * size / (beta*size + size) = alpha/(1+beta),
 # so smooth1 == smooth0 / (1+beta), so (1+beta) = smooth0/smooth1, so beta=smooth0/smooth1 - 1
 smooth = smooth0 * size / ((smooth0/smooth1 - 1) * rank + size)
-mean = _mean(residual_rms, exclude_dims=[0], keepdim=True)
-residual_rms += group["eps"] + smooth * mean
-residual_rms = residual_rms / _mean(residual_rms, exclude_dims=[0], keepdim=True)
-# apply the maximum via a softmin function, softmin(x,y) = 1/(1/x + 1/y)
-residual_rms = 1. / (1. / residual_rms + 1. / relative_rms_max)
-if random.random() < 0.1:
-skip = 10 if S.shape[-1] > 40 else 1
-logging.info(f"Smoothed param_rms from {S.sqrt()[0,0,::skip]} to {residual_rms[0,0,::skip]}, param_rms={param_rms[0,0,::skip]}")
+# add rank-dependent smoothing amount to diagonal of P_prime.  _diag() returns an aliased tensor.
+# we don't need to multiply `smooth` by anything, because at this point, P_prime should have
+# diagonal elements close to 1.
+_diag(P_prime).add_(smooth)
+P_norm = self._smooth_cov(P_norm,
+group["min_lr_factor"][0],
+group["max_lr_factor"][0])
+# Remove the diagonal preconditioning on P_norm, giving us stage-1-smoothed
+# version of P_prime.
+P_prime = P_norm * P_prime_scale
+# Make sure G_prime has unit mean and no eigenvalue is super small.  Note, G_prime
+# is already diagonal.
+G_prime_mean = _mean(G_prime, exclude_dims=[0], keepdim=True)
+G_prime_smooth = 0.001
+# make sure G_prime has no zero eigs, and is unit mean.
+G_prime = ((G_prime + eps + G_prime_smooth * G_prime_mean) /
+(G_prime_mean * (1+G_prime_smooth) + eps))
+G_prime_rms = G_prime.sqrt()
+G_prime_scale = G_prime_rms.unsqueeze(-1) * G_prime_rms.unsqueeze(-2)
+# P_gnorm is a version of P_prime that is scaled relative to G, i.e.
+# scaled in such a way that would make G the unit matrix.
+P_gnorm = P_prime / G_prime_scale
+# Apply another round of smoothing "relative to G"
+P_gnorm = self._smooth_cov(P_gnorm,
+group["min_lr_factor"][1],
+group["max_lr_factor"][1])
+# Undo the scaling relative to G, so we have stage-2-smoothed version of P_prime.
+P_prime = P_gnorm * G_prime_scale
+# Apply a 3rd round of smoothing
+P_prime = self._smooth_cov(P_prime,
+group["min_lr_factor"][2],
+group["max_lr_factor"][2])
+return P_prime
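As a concrete illustration of the rank-based smoothing amount above, using the commit's defaults param_rms_smooth0=0.75 and param_rms_smooth1=0.25; a quick standalone calculation, not code from the commit:

    # smooth = smooth0 * size / ((smooth0/smooth1 - 1) * rank + size)
    smooth0, smooth1 = 0.75, 0.25   # defaults from the constructor above
    def smoothing_amount(size, rank):
        return smooth0 * size / ((smooth0 / smooth1 - 1) * rank + size)

    print(smoothing_amount(512, 0))      # 0.75  == smooth0: no observations, smooth heavily
    print(smoothing_amount(512, 512))    # 0.25  == smooth1: rank equals size
    print(smoothing_amount(512, 2048))   # ~0.083: well-estimated covariance, little smoothing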
-# U shape: (batch_size, num_blocks, block_size, block_size),
-# interpreted as
-# residual_rms shape: (batch_size, num_blocks, block_size).
-# so in terms of matrix multiplication, we are computing X_p = matmul(U_p, residual_rms.diag())
-X_p = U_p * residual_rms.unsqueeze(-2)
-param_cov_norm_smoothed = torch.matmul(X_p, X_p.transpose(2, 3))
-# Undo the scaling by the diagonal of param_cov
-param_cov_proj_smoothed = param_cov_norm_smoothed * param_cov_inv_scale
-# Undo the projection by U.
-param_cov_smoothed = torch.matmul(U_g, torch.matmul(param_cov_proj_smoothed,
-U_g.transpose(2, 3)))
-return param_cov_smoothed
+def _smooth_cov(self,
+X: Tensor,
+min_eig: float,
+max_eig: float,
+power: float = 1.0) -> Tensor:
+"""
+Returns a `smoothed` version of a symmetric positive definite covariance matrix
+[with block-diagonal structure, in a batch].  This is done without SVD (which
+can be very slow).
+The eigenvalues L will be transformed as:
+     L = L + min_eig * L.mean() + eps
+     L /= L.mean()
+     L = 1 / (1/L + 1/max_eig)   # soft-min between L and max_eig
+     L = L ** power   # need SVD for this, will get rid of the requirement later.
+     L /= L.mean()
+# Note on approximation functions like x^0.75 for smallish x: on wolframalpha, type:
+# plot x^0.75 and 0.05 + (1.1x - 0.18 x^2 + 0.02 x^3) for x from 0 to 10
+# [this starts to diverge after 5 or so]
+Args:
+  min_eig: minimum allowed eigenvalue of returned X
+  max_eig: maximum allowed eigenvalue of returned X
+  power: power to take eigenvalues to
+  X: the batch of symmetric positive definite tensors we are smoothing;
+    of shape (batch_size, num_blocks, block_size, block_size)
+"""
+if power != 1.0:
+U, S, _ = _svd(X)
+eps = 1.0e-10
+S_mean = _mean(S, exclude_dims=[0], keepdim=True)
+S = S + min_eig * S_mean + eps
+S_mean = S_mean * (1 + min_eig) + eps
+S = S / S_mean
+S = 1. / (1./S + 1./max_eig)
+S = S ** power
+S = S / _mean(S, exclude_dims=[0], keepdim=True)
+return torch.matmul(U * S.unsqueeze(-2), U.transpose(2, 3))
+else:
+diag = _diag(X)  # Aliased with X
+mean_eig = _mean(diag, exclude_dims=[0], keepdim=True)
+eps = 1.0e-10  # prevent division by zero
+diag += (mean_eig * min_eig + eps)
+cur_diag_mean = mean_eig * (1 + min_eig) + eps
+# The following 2 statements will be equivalent to:
+#   L /= L.mean()
+#   L = 1 / (1/L + 1/max_eig)   # soft-min between L and max_eig
+# if L is the eigenvalues
+X = (X.inverse() + 1/(max_eig * cur_diag_mean.unsqueeze(-1))).inverse()
+X /= _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1)
+return X
 def _diagonalize_grad_cov(self,
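The SVD-free branch relies on the identity that adding a multiple of the identity to X^{-1} and inverting again applies the soft-min 1/(1/L + 1/c) to each eigenvalue L of X. A standalone sketch of that identity with an explicit identity term (single matrix; illustrative only, not the commit's batched code):

    import torch

    torch.manual_seed(0)
    n = 5
    A = torch.randn(n, n, dtype=torch.float64)
    X = A @ A.T + 0.1 * torch.eye(n, dtype=torch.float64)

    c = 4.0  # plays the role of max_eig (times the mean eigenvalue) in _smooth_cov
    Y = torch.linalg.inv(torch.linalg.inv(X) + torch.eye(n, dtype=torch.float64) / c)

    L = torch.linalg.eigvalsh(X)
    L_soft = 1.0 / (1.0 / L + 1.0 / c)   # soft-min of each eigenvalue with c
    print(torch.allclose(torch.linalg.eigvalsh(Y), L_soft, atol=1e-8))  # True: eigenvalues transformed as 1/(1/L + 1/c)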
@@ -1144,22 +1321,27 @@ def _move_dim(x: Tensor, orig_dim: int, new_dim: int) -> Tensor:
 def _diag(x: Tensor):
 """
-like torch diag(), but supports batch dim, i.e. input of shape (B, M, M) returns
+like torch diag(), this returns the diagonal of a matrix; but it returns an
+aliased tensor, and it supports batch dims,
+i.e. input of shape (B, M, M) returns
 output of shape (B, M), or input of shape (A, B, M, M) returns output of shape
-(A, B, M)
+(A, B, M).
 """
+stride = x.stride()
 if x.ndim == 3:
 (B, M, M2) = x.shape
 assert M == M2
-stride = x.stride()
-return x.as_strided(size=(B, M), stride=(stride[0], stride[1] + stride[2])).contiguous()
+ans = x.as_strided(size=(B, M), stride=(stride[0], stride[1] + stride[2]))
 elif x.ndim == 4:
 (B, C, M, M2) = x.shape
 assert M == M2
-stride = x.stride()
-return x.as_strided(size=(B, C, M), stride=(stride[0], stride[1], stride[2] + stride[3])).contiguous()
-else:
-return x.diag()
+ans = x.as_strided(size=(B, C, M), stride=(stride[0], stride[1], stride[2] + stride[3])).contiguous()
+elif x.ndim == 2:
+(M, M2) = x.shape
+assert M == M2
+ans = x.as_strided(size=(M,), stride=(stride[0] + stride[1],)).contiguous()
+return ans
 def _sum(x: Tensor,
 exclude_dims: List[int] = [],
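A small standalone check of the as_strided trick used by the new _diag, for the 3-dim case where the result stays aliased with the input (illustrative only, not code from the commit):

    import torch

    x = torch.arange(18, dtype=torch.float32).reshape(2, 3, 3)
    s = x.stride()
    d = x.as_strided(size=(2, 3), stride=(s[0], s[1] + s[2]))  # the (B, M, M) -> (B, M) diagonal
    print(torch.equal(d, torch.diagonal(x, dim1=1, dim2=2)))   # True

    # Because no .contiguous() is taken, d aliases x: in-place edits show up in the matrices.
    d.add_(100.0)
    print(x[0])   # diagonal entries of x[0] are now 100., 104., 108.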