Refactoring, putting tunable values in constructor, a little cleanup

This commit is contained in:
Daniel Povey 2022-07-25 04:31:42 +08:00
parent 8efc512823
commit 06718052ec


@ -107,9 +107,6 @@ class PrAdam(BatchedOptimizer):
scale of each parameter tensor. If each parameter were decomposed
as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, size_lr_scale
is the scaling factor on the learning rate of p_scale.
param_pow: Power on the parameter covariance matrix; 1.0 means learn proportionally
to the parameter rms (1.0 will be too much; it should be between 0 and 1).
This is one of the most important tunable factors, along with param_cov_max.
param_rms_smooth0: Limiting value of smoothing proportion for parameter matrix, as
assumed rank of param covariance [==product of sizes on the other
tensor dims] approaches 0.
@ -117,18 +114,21 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
param covariance equals the dimension of the covariance matrix.
param_rms_smooth{0,1} determine the smoothing proportions for other
conditions.
param_cov_min: [IMPORTANT] A 3-tuple of minimums of the diagonal values of the parameter
covariance, normalized in 3 different ways: relative to its own
diagonal, scaled by the grad covariance, and in the canonical basis.
With param_cov_max, defines how "aggressive" we allow our update to
be.
param_cov_max: [IMPORTANT] A 3-tuple of maximums of the diagonal values of the parameter
covariance, normalized in 3 different ways: relative to its own
diagonal, scaled by the grad covariance, and in the canonical basis.
param_pow: This was mainly added for development and experimentation purposes;
it allows you to smooth the parameter covariance by taking it to
a power, but if this is not 1.0 it will cause a speed penalty because
it requires SVD.
cov_min,cov_max: [IMPORTANT] 4-tuples of minimums and maximums of the diagonal values of
covariance matrices, after normalizing them to unit mean. The first 3 elements
are for smoothing the parameter covariance, normalized in 3 different ways:
(1) relative to its own diagonal (in a basis that diagonalizes the grad
covariance);
(2) multiplied by the grad covariance;
(3) in the canonical basis.
The 4th element is for smoothing the grad covariance used in (2). (A rough
sketch of how each (min, max, pow) triple is applied follows the cov_pow
entry below.)
cov_pow: This was mainly added for development and experimentation purposes;
it allows you to smooth the parameter covariance matrices, by taking them
to a power, at stages (1), (2) and (3) of the smoothing mentioned above,
and likewise the gradient covariance matrix used in stage (2). If the
first 3 values are not 1.0 this will cause a measurable speed penalty because it
requires an SVD. We recommend leaving all of these at 1.0.
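As a rough illustration (not part of the actual code, and the helper name here
is invented), each (min, max, pow) triple acts on a unit-mean set of covariance
eigenvalues roughly as follows; this mirrors the formulas the code applies to
the grad covariance (element 3) further below, and the per-stage details inside
_smooth_cov may differ:

    def smooth_eigs(eigs, cov_min, cov_max, cov_pow, eps=1.0e-08):
        # eigs: a torch.Tensor of positive covariance eigenvalues.
        mean = eigs.mean()
        # Soft floor at roughly cov_min * mean, renormalizing so the mean is 1.
        eigs = (eigs + eps + cov_min * mean) / (mean * (1 + cov_min) + eps)
        # Soft maximum: softmin(x, cov_max) = 1 / (1/x + 1/cov_max).
        eigs = 1. / (1. / eigs + 1. / cov_max)
        # Optional power; anything other than 1.0 costs extra computation.
        if cov_pow != 1.0:
            eigs = eigs ** cov_pow
        return eigs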
eps: A general-purpose epsilon to prevent division by zero
param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
learning the scale on the parameters (we'll constrain the rms of each non-scalar
@ -159,9 +159,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
lr=3e-02,
betas=(0.9, 0.98),
size_lr_scale=0.1,
param_cov_min=(0.05, 0.01, 0.04),
param_cov_max=(10.0, 40.0, 5.0),
param_pow=(1.0, 1.0, 1.0),
cov_min=(0.05, 0.01, 0.04, 0.0001),
cov_max=(10.0, 40.0, 5.0, 400.0),
cov_pow=(1.0, 1.0, 1.0, 1.0),
param_rms_smooth0=0.4,
param_rms_smooth1=0.2,
eps=1.0e-08,
@ -182,9 +182,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
defaults = dict(
lr=lr,
size_lr_scale=size_lr_scale,
param_cov_min=param_cov_min,
param_cov_max=param_cov_max,
param_pow=param_pow,
cov_min=cov_min,
cov_max=cov_max,
cov_pow=cov_pow,
param_rms_smooth0=param_rms_smooth0,
param_rms_smooth1=param_rms_smooth1,
betas=betas,
@ -289,7 +289,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
p, memory_format=torch.preserve_format
)
ignore_rank1_dims = True # a config value, can change this (TODO: tune)
# ignore_rank1_dims = True means we don't do our basis-changing update
# on tensors like biases that only have one non-trivial dimension.
# "rank1" refers to the rank of our estimate of the parameter
# covariance, if we were to estimate it just from the current parameter
# value.
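# For example, a bias of shape (out_dim,) has only one non-trivial dim, and
# a covariance estimated from the current value alone would be the rank-1
# outer product of the bias with itself; such an estimate is too poorly
# conditioned to base a basis-changing update on.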
ignore_rank1_dims = True
trivial_update = True
for dim in range(1, p.ndim):
@ -325,7 +330,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
Q = torch.eye(block_size, block_size, **kwargs).unsqueeze(0).unsqueeze(0).expand(
batch_size, num_blocks, block_size, block_size).contiguous()
state[f"Q_{dim}"] = Q
# param_cov_{dim} is the averaged-over-time covariance of parameters on this dimension, treating
# cov_{dim} is the averaged-over-time covariance of parameters on this dimension, treating
# all other dims as a batch axis. Also initialize as identity.
state[f"param_cov_{dim}"] = Q.clone()
@ -1046,23 +1051,28 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
_diag(P_norm).add_(smooth)
P_norm = self._smooth_cov(P_norm,
group["param_cov_min"][0],
group["param_cov_max"][0],
group["param_pow"][0])
group["cov_min"][0],
group["cov_max"][0],
group["cov_pow"][0])
# Remove the diagonal preconditioning on P_norm, giving us stage-1-smoothed
# version of P_prime.
P_prime = P_norm * P_prime_scale
# Make sure G_prime has unit mean and no eigenvalue is super small. Note that G_prime
# is already diagonalized; the variable G_prime is just the tensor of eigenvalues.
G_prime_mean = _mean(G_prime, exclude_dims=[0], keepdim=True)
G_prime_smooth = 0.0001
# make sure G_prime has no zero eigs, and is unit mean.
G_prime = ((G_prime + eps + G_prime_smooth * G_prime_mean) /
(G_prime_mean * (1+G_prime_smooth) + eps))
# it now has unit mean.
G_prime_max = 400.0
G_prime = 1. / (1./G_prime + 1./G_prime_max) # apply max
if True:
# This block smooths G_prime.
# Make sure G_prime has unit mean and no eigenvalue is super small. Note that G_prime
# is already diagonalized; the variable G_prime is just the tensor of eigenvalues.
G_prime_mean = _mean(G_prime, exclude_dims=[0], keepdim=True)
G_prime_min = group["cov_min"][3]
# make sure G_prime has no zero eigs, and is unit mean.
G_prime = ((G_prime + eps + G_prime_min * G_prime_mean) /
(G_prime_mean * (1+G_prime_min) + eps))
# it now has unit mean.
G_prime_max = group["cov_max"][3]
G_prime = 1. / (1./G_prime + 1./G_prime_max) # apply max
G_prime_pow = group["cov_pow"][3]
if G_prime_pow != 1.0:
G_prime = G_prime ** G_prime_pow
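# Illustrative numbers (not from the code): with the default cov_max[3] == 400.0,
# the soft max maps an eigenvalue of 400 to 200 and one of 4000 to about 364, so
# very large eigenvalues are tamed rather than hard-clipped; with cov_min[3] == 0.0001
# a zero eigenvalue is floored to roughly 1e-4 after the unit-mean renormalization.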
G_prime_rms = G_prime.sqrt()
G_prime_scale = G_prime_rms.unsqueeze(-1) * G_prime_rms.unsqueeze(-2)
@ -1073,17 +1083,17 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
P_gnorm = P_prime * G_prime_scale
# Apply another round of smoothing "relative to G"
P_gnorm = self._smooth_cov(P_gnorm,
group["param_cov_min"][1],
group["param_cov_max"][1],
group["param_pow"][1])
group["cov_min"][1],
group["cov_max"][1],
group["cov_pow"][1])
# Undo the scaling relative to G, so we have stage-2-smoothed version of P_prime.
P_prime = P_gnorm / G_prime_scale
# Apply a 3rd round of smoothing in the canonical basis.
P_prime = self._smooth_cov(P_prime,
group["param_cov_min"][2],
group["param_cov_max"][2],
group["param_pow"][2])
group["cov_min"][2],
group["cov_max"][2],
group["cov_pow"][2])
return P_prime
def _smooth_cov(self,
@ -1378,66 +1388,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
while Q.ndim < M.ndim:
Q = Q.unsqueeze(2)
# now Q has shape (batch_size, num_blocks, 1, 1, block_size, block_size)
M = torch.matmul(M, Q) # (batch_size, num_blocks, x, y, z, block_size)
if M.ndim < Q.ndim: # special case where M shape is e.g. (batch_size, num_blocks, x)
M = torch.matmul(M.unsqueeze(-2), Q).squeeze(-2)
else:
M = torch.matmul(M, Q) # (batch_size, num_blocks, x, y, z, block_size)
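# (In the special case above, each block of M is just a vector; unsqueeze(-2)
# turns it into a row vector of shape (..., 1, block_size) so it can be
# matmul'd with Q, and squeeze(-2) restores the original shape afterwards.)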
M = _move_dim(M, 1, -2) # (batch_size, x, y, z, num_blocks, block_size)
M = M.reshape(*M.shape[:-2], size) # (batch_size, x, y, z, size)
M = M.transpose(-1, dim) # (batch_size, x, size, y, z)
return M
def _smooth_param_rms(self,
group: dict,
rms: Tensor,
rank: int) -> Tensor:
"""
Smooths, and normalizes to mean=1, a tensor of shape something like (batch_size, 1, size, 1, 1),
where for the nontrivial dim `size` it contains diagonalized parameter rms values; this is for
one dimension of a multi-dimensional tensor.
This will be used to construct a learning-rate matrix.
Args:
group: dict for configuration values
rms: A tensor of shape (batch_size, [1,1,..], size, [1,1,1,..]), representing
(for each batch element) a list of root-mean-square values, one per
dimension of the space of size `size`.
rank: the assumed rank of the covariance matrix from which rms was derived,
used to decide how much smoothing to apply. This is actually the
minimal rank, assuming the parameter matrix were to stay fixed during training,
but it is still relevant to how robust the parameter covariance estimate is.
Returns:
a Tensor with the same shape as `rms` but with some smoothing applied so there
are no values too close to zero.
"""
smooth0 = group["param_rms_smooth0"]
smooth1 = group["param_rms_smooth1"]
param_cov_max = group["param_cov_max"]
param_pow = group["param_pow"]
eps = group["eps"]
batch_size = rms.shape[0]
size = rms.numel() // batch_size
rms = rms ** param_pow
# want expr to be of the form: smooth = alpha * size / (beta*rank + size)
# from rank==0, we get smooth0 = alpha * size/size, so alpha = smooth0.
# from setting rank==size, we get smooth1 = alpha * size / (beta*size + size) = alpha/(1+beta),
# so smooth1 == smooth0 / (1+beta), so (1+beta) = smooth0/smooth1, so beta=smooth0/smooth1 - 1
smooth = smooth0 * size / ((smooth0/smooth1 - 1) * rank + size)
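# Sanity check (illustrative, not part of the original comment): with the default
# param_rms_smooth0=0.4, param_rms_smooth1=0.2, rank==0 gives smooth = 0.4*size/size = 0.4,
# and rank==size gives smooth = 0.4*size / ((0.4/0.2 - 1)*size + size) = 0.4/2 = 0.2,
# matching the two intended endpoints.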
mean = _mean(rms, exclude_dims=[0], keepdim=True)
rms += eps + smooth * mean
new_mean = (eps + (smooth + 1) * mean) # mean of modified rms.
ans = rms / new_mean
# Apply a `soft min` of param_cov_max via the formula
# softmin(x,y) = 1/(1/x + 1/y).
ans = 1. / (1. / ans + 1. / param_cov_max)
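# Illustrative numbers (not from the code): softmin(10.0, 5.0) = 1/(1/10 + 1/5) = 3.33...,
# so values well above the max are pulled down towards it, while small values are
# left nearly unchanged, e.g. softmin(0.5, 5.0) is about 0.45.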
# and renormalize to mean=1.
ans /= _mean(ans, exclude_dims=[0], keepdim=True)
return ans
def _move_dim(x: Tensor, orig_dim: int, new_dim: int) -> Tensor:
"""
@ -2197,9 +2156,8 @@ def _test_eve_cain():
for iter in [3, 2]:
fix_random_seed(42)
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
# TODO: find out why this is not converging...
hidden_dim = 512
hidden_dim = 768
m = torch.nn.Sequential(Linear(E, hidden_dim),
torch.nn.PReLU(),
Linear(hidden_dim, hidden_dim),