Draft of new way of smoothing param_rms, diagonalized by grad

Daniel Povey 2022-07-22 06:37:20 +08:00
parent a63afe348a
commit 4da4e69fba


@@ -142,10 +142,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
lr=3e-02,
betas=(0.9, 0.98),
size_lr_scale=0.1,
param_pow=0.75,
param_pow=1.0,
param_rms_smooth0=0.75,
param_rms_smooth1=0.25,
max_lr_factor=4.0,
max_lr_factor=10.0,
eps=1.0e-08,
param_min_rms=1.0e-05,
param_max_rms=2.0,
@@ -593,10 +593,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
size = p.shape[dim]
try:
Q = state[f"Q_{dim}"]
param_cov = state[f"param_cov_{dim}"]
except KeyError:
assert size == 1 or size == numel, size
continue # e.g. size == 1 or size == numel:
continue # e.g. size == 1 or size == numel
param_cov = self._get_smoothed_param_cov(group, p, state, dim)
# param_cov has the same shape as Q
(batch_size, num_blocks, block_size, block_size) = Q.shape
@@ -683,7 +684,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# rms shape: (batch_size, 1, size, 1, 1)
rms = _mean(cur_p**2, exclude_dims=[0,dim], keepdim=True).sqrt()
rank = numel // size
smoothed_rms = self._smooth_param_rms(group, rms, rank)
# we did other kinds of smoothing in _get_smoothed_param_cov
#smoothed_rms = self._smooth_param_rms(group, rms, rank)
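# (with the new default param_pow=1.0 this leaves rms unchanged, i.e. we fully
# normalize out the parameter rms on this dim)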
smoothed_rms = rms ** group["param_pow"]
cur_scales[dim] = smoothed_rms
cur_p /= smoothed_rms # normalize/"whiten" cur_p on this dim..
@@ -727,6 +730,95 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
state[f"Q_{dim}"] *= scale
state["last_param_scale_update"] = state["step"]
    def _get_smoothed_param_cov(self,
                                group: dict,
                                p: Tensor,
                                state: dict,
                                dim: int) -> Tensor:
"""
This function returns a modified/smoothed version of the parameter covariance
state[f"param_cov_{dim}"], which is an estimate of the covariance of the parameter
p, averaged over time, and taken over dimension `dim` of the tensor.
The smoothing done here limits the extend to which the parameter covariance
can be strongly "off-diagonal" with respect to the gradient covariance. That is:
if the parameter covariance is just the gradient covariance to some power, this
function does no smoothing; but if it is highly off-diagonal we do more smoothing.
"""
        param_cov = state[f"param_cov_{dim}"]  # (batch_size, num_blocks, block_size, block_size)
        grad_cov = state[f"grad_cov_{dim}"]  # (batch_size, num_blocks, block_size, block_size)
        (batch_size, num_blocks, block_size, block_size) = param_cov.shape
        U_g, _, _ = _svd(grad_cov)  # U_g diagonalizes grad_cov, in the sense that U_g^T grad_cov U_g is diagonal.
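        # (grad_cov is symmetric positive semi-definite, so its SVD coincides with its
        # eigendecomposition and the basis U_g is orthonormal.)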
        # param_cov_proj is param_cov expressed in a different orthonormal basis: the
        # one that diagonalizes grad_cov.
        param_cov_proj = torch.matmul(U_g.transpose(2, 3), torch.matmul(param_cov, U_g))
        # param_cov_eps is probably not critical since we don't expect to see very
        # small values here; it is applied as a floor in case roundoff causes
        # negative values.
        param_cov_eps = 1.0e-05
        param_rms = _diag(param_cov_proj).clamp_(min=param_cov_eps).sqrt()
        param_cov_inv_scale = param_rms.unsqueeze(-1) * param_rms.unsqueeze(-2)
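        # Element [i,j] of param_cov_inv_scale is param_rms[i] * param_rms[j], so dividing
        # by it below is the same as pre- and post-multiplying by diag(1/param_rms).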
        # param_cov_norm should have diagonal values close to 1.0 (only not
        # exactly 1.0 due to param_cov_eps and roundoff).
        param_cov_norm = param_cov_proj / param_cov_inv_scale
        # OK, this is where we do the smoothing.
        # Decompose param_cov_norm, which is symmetric, as U_p S U_p^T.
        U_p, S, _ = _svd(param_cov_norm)
        residual_rms = S.sqrt()
        # relative_rms_pow flattens the spectrum of residual_rms toward 1.0;
        # relative_rms_max limits how large any element of it can become.
        relative_rms_pow = 0.7
        relative_rms_max = 4.0
        residual_rms = residual_rms ** relative_rms_pow
        residual_rms /= _mean(residual_rms, exclude_dims=[0], keepdim=True)
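        # (Normalizing residual_rms to mean 1.0 keeps only the shape of the spectrum here;
        # the overall scale of param_cov is restored later via param_cov_inv_scale.)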
        if True:
            # Smooth according to the rank of the observation.
            size = p.shape[dim]
            rank = p.numel() // (size * batch_size)
            smooth0 = group["param_rms_smooth0"]
            smooth1 = group["param_rms_smooth1"]
            # We want an expression of the form: smooth = alpha * size / (beta*rank + size).
            # From rank == 0 we get smooth0 = alpha * size/size, so alpha = smooth0.
            # From setting rank == size we get smooth1 = alpha * size / (beta*size + size) = alpha/(1+beta),
            # so smooth1 == smooth0 / (1+beta), so (1+beta) = smooth0/smooth1, so beta = smooth0/smooth1 - 1.
            smooth = smooth0 * size / ((smooth0/smooth1 - 1) * rank + size)
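            # Quick check with the current defaults smooth0=0.75, smooth1=0.25:
            # rank == 0 gives smooth = 0.75, while rank == size gives
            # 0.75 * size / (2*size + size) = 0.25, matching the two endpoints.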
            mean = _mean(residual_rms, exclude_dims=[0], keepdim=True)
            residual_rms += group["eps"] + smooth * mean
            residual_rms = residual_rms / _mean(residual_rms, exclude_dims=[0], keepdim=True)
        # Apply the maximum via a softmin function, softmin(x,y) = 1/(1/x + 1/y).
        residual_rms = 1. / (1. / residual_rms + 1. / relative_rms_max)
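        # E.g. with relative_rms_max = 4.0, a value of 8.0 maps to 1/(1/8 + 1/4) ~= 2.67,
        # while a value of 1.0 maps to 0.8, so small values pass through nearly unchanged.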
        if random.random() < 0.1:
            skip = 10 if S.shape[-1] > 40 else 1
            logging.info(f"Smoothed param_rms from {S.sqrt()[0,0,::skip]} to {residual_rms[0,0,::skip]}, param_rms={param_rms[0,0,::skip]}")
        # U_p shape:          (batch_size, num_blocks, block_size, block_size)
        # residual_rms shape: (batch_size, num_blocks, block_size)
        # In terms of matrix multiplication we are computing
        # X_p = matmul(U_p, residual_rms.diag()), i.e. scaling the columns of U_p.
        X_p = U_p * residual_rms.unsqueeze(-2)
        param_cov_norm_smoothed = torch.matmul(X_p, X_p.transpose(2, 3))
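        # This equals U_p diag(residual_rms**2) U_p^T: the same eigenvectors as
        # param_cov_norm, but with its eigenvalues S replaced by residual_rms**2.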
        # Undo the scaling by the diagonal of param_cov.
        param_cov_proj_smoothed = param_cov_norm_smoothed * param_cov_inv_scale
        # Undo the projection by U_g.
        param_cov_smoothed = torch.matmul(U_g, torch.matmul(param_cov_proj_smoothed,
                                                            U_g.transpose(2, 3)))
        return param_cov_smoothed
    def _diagonalize_grad_cov(self,
                              group: dict,
                              p: Tensor,