Refactoring, putting tunable values in constructor, a little cleanup

Daniel Povey 2022-07-25 04:31:42 +08:00
parent 8efc512823
commit 06718052ec


@@ -107,9 +107,6 @@ class PrAdam(BatchedOptimizer):
         scale of each parameter tensor.  If each parameter were decomposed
         as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, size_lr_scale
         is the scaling factor on the learning rate of p_scale.
-    param_pow: Power on the parameter covariance matrix; 1.0 means learn proportional
-        to parameter rms (1.0 will be too much; it should be between 0 and 1).
-        This is one of the most important tunable factors, along with param_cov_max.
     param_rms_smooth0: Limiting value of the smoothing proportion for the parameter matrix, as
         the assumed rank of the param covariance [== product of sizes on the other
         tensor dims] approaches 0.
@@ -117,18 +114,21 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         param covariance equals the dimension of the covariance matrix.
         param_rms_smooth{0,1} determine the smoothing proportions for other
         conditions.
-    param_cov_min: [IMPORTANT] A 3-tuple of minimums of the diagonal values of the parameter
-        covariance, normalized in 3 different ways: relative to its own
-        diagonal, scaled by the grad covariance, and in the canonical basis.
-        With param_cov_max, defines how "aggressive" we allow our update to
-        be.
-    param_cov_max: [IMPORTANT] A 3-tuple of maximums of the diagonal values of the parameter
-        covariance, normalized in 3 different ways: relative to its own
-        diagonal, scaled by the grad covariance, and in the canonical basis.
-    param_pow: This was mainly added for development and experimentation purposes;
-        it allows you to smooth the parameter covariance by taking it to
-        a power, but if this is not 1.0 it will cause a speed penalty because
-        it requires SVD.
+    cov_min,cov_max: [IMPORTANT] 4-tuples of minimums and maximums of the diagonal values of
+        covariance matrices, after normalizing to unit mean.  The first 3 are
+        for smoothing the parameter covariance, normalized in 3 different ways:
+        (1) relative to its own diagonal (in a basis that diagonalizes the grad
+            covariance);
+        (2) multiplied by the grad covariance;
+        (3) in the canonical basis.
+        (4) is for smoothing the grad covariance used for (2).
+    cov_pow: This was mainly added for development and experimentation purposes;
+        it allows you to smooth the parameter covariance matrices at
+        stages (1), (2) and (3) of smoothing mentioned above, and also
+        the gradient covariance matrix used in stage (2) of smoothing.  If the
+        first 3 values are not 1.0, it will cause a measurable speed penalty because it
+        requires SVD.  Recommended to leave all of these at 1.0.
     eps: A general-purpose epsilon to prevent division by zero
     param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
         learning the scale on the parameters (we'll constrain the rms of each non-scalar
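As a concrete illustration of these tuples, here is a minimal sketch (smooth_eigs
is a hypothetical helper, not code from this commit) of the floor-then-soft-cap
pattern that the hunks below apply to the eigenvalues of a covariance matrix
after normalizing it to unit mean:

    import torch

    def smooth_eigs(eigs: torch.Tensor, cov_min: float, cov_max: float,
                    eps: float = 1e-8) -> torch.Tensor:
        # eigs: nonnegative eigenvalues of one covariance matrix.
        mean = eigs.mean()
        # Floor: mix in cov_min times the mean and renormalize to unit mean,
        # so no eigenvalue ends up far below cov_min.
        eigs = (eigs + eps + cov_min * mean) / (mean * (1 + cov_min) + eps)
        # Cap: soft minimum with cov_max, i.e. softmin(x, y) = 1 / (1/x + 1/y).
        return 1.0 / (1.0 / eigs + 1.0 / cov_max)

    print(smooth_eigs(torch.tensor([0.0, 1.0, 1000.0]), cov_min=0.05, cov_max=10.0))

The corresponding cov_pow entry, when not 1.0, would additionally raise the
smoothed eigenvalues to a power; applying that in a non-diagonal basis is what
requires the SVD mentioned above.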
@@ -159,9 +159,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             lr=3e-02,
             betas=(0.9, 0.98),
             size_lr_scale=0.1,
-            param_cov_min=(0.05, 0.01, 0.04),
-            param_cov_max=(10.0, 40.0, 5.0),
-            param_pow=(1.0, 1.0, 1.0),
+            cov_min=(0.05, 0.01, 0.04, 0.0001),
+            cov_max=(10.0, 40.0, 5.0, 400.0),
+            cov_pow=(1.0, 1.0, 1.0, 1.0),
             param_rms_smooth0=0.4,
             param_rms_smooth1=0.2,
             eps=1.0e-08,
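Hypothetical usage after this change (the import path and model are
illustrative assumptions, not from the commit): the old 3-tuples become
4-tuples, with the new 4th entry controlling the grad-covariance smoothing.

    import torch
    # from optim import PrAdam   # assumed import path for the class in this diff

    model = torch.nn.Linear(256, 256)
    optimizer = PrAdam(model.parameters(),
                       lr=3e-02,
                       betas=(0.9, 0.98),
                       cov_min=(0.05, 0.01, 0.04, 0.0001),
                       cov_max=(10.0, 40.0, 5.0, 400.0),
                       cov_pow=(1.0, 1.0, 1.0, 1.0))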
@@ -182,9 +182,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         defaults = dict(
             lr=lr,
             size_lr_scale=size_lr_scale,
-            param_cov_min=param_cov_min,
-            param_cov_max=param_cov_max,
-            param_pow=param_pow,
+            cov_min=cov_min,
+            cov_max=cov_max,
+            cov_pow=cov_pow,
             param_rms_smooth0=param_rms_smooth0,
             param_rms_smooth1=param_rms_smooth1,
             betas=betas,
@@ -289,7 +289,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                     p, memory_format=torch.preserve_format
                 )
-            ignore_rank1_dims = True  # a config value, can change this (TODO: tune)
+            # ignore_rank1_dims = True means we don't do our basis-changing update
+            # on tensors like biases that only have one non-trivial dimension.
+            # "rank1" refers to the rank of our estimate of the parameter
+            # covariance, if we were to estimate it just from the current parameter
+            # value.
+            ignore_rank1_dims = True

             trivial_update = True
             for dim in range(1, p.ndim):
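A sketch of the idea behind this flag (the helper below is hypothetical, not
from the commit): for a bias-like tensor the product of sizes over the other
dims is 1, so a parameter covariance estimated from the current parameter
value alone would have rank 1, and the basis-changing update is skipped.

    import torch

    def has_single_nontrivial_dim(p: torch.Tensor) -> bool:
        # True if at most one dim has size > 1, e.g. a bias of shape (out_dim,).
        return sum(1 for s in p.shape if s > 1) <= 1

    print(has_single_nontrivial_dim(torch.zeros(512)))       # True: bias-like
    print(has_single_nontrivial_dim(torch.zeros(512, 256)))  # False: weight matrix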
@@ -325,7 +330,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             Q = torch.eye(block_size, block_size, **kwargs).unsqueeze(0).unsqueeze(0).expand(
                 batch_size, num_blocks, block_size, block_size).contiguous()
             state[f"Q_{dim}"] = Q
-            # param_cov_{dim} is the averaged-over-time covariance of parameters on this dimension, treating
+            # cov_{dim} is the averaged-over-time covariance of parameters on this dimension, treating
             # all other dims as a batch axis.  Also initialize as identity.
             state[f"param_cov_{dim}"] = Q.clone()
@@ -1046,23 +1051,28 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         _diag(P_norm).add_(smooth)

         P_norm = self._smooth_cov(P_norm,
-                                  group["param_cov_min"][0],
-                                  group["param_cov_max"][0],
-                                  group["param_pow"][0])
+                                  group["cov_min"][0],
+                                  group["cov_max"][0],
+                                  group["cov_pow"][0])

         # Remove the diagonal preconditioning on P_norm, giving us the stage-1-smoothed
         # version of P_prime.
         P_prime = P_norm * P_prime_scale

-        # Make sure G_prime has unit mean and no eigenvalue is super small.  Note, G_prime
-        # is already diagonalized; the variable G_prime is just the tensor of eigenvalues.
-        G_prime_mean = _mean(G_prime, exclude_dims=[0], keepdim=True)
-        G_prime_smooth = 0.0001
-        # make sure G_prime has no zero eigs, and has unit mean.
-        G_prime = ((G_prime + eps + G_prime_smooth * G_prime_mean) /
-                   (G_prime_mean * (1 + G_prime_smooth) + eps))
-        # it now has unit mean.
-        G_prime_max = 400.0
-        G_prime = 1. / (1. / G_prime + 1. / G_prime_max)  # apply max
+        if True:
+            # This block smooths G_prime.
+            # Make sure G_prime has unit mean and no eigenvalue is super small.  Note, G_prime
+            # is already diagonalized; the variable G_prime is just the tensor of eigenvalues.
+            G_prime_mean = _mean(G_prime, exclude_dims=[0], keepdim=True)
+            G_prime_min = group["cov_min"][3]
+            # make sure G_prime has no zero eigs, and has unit mean.
+            G_prime = ((G_prime + eps + G_prime_min * G_prime_mean) /
+                       (G_prime_mean * (1 + G_prime_min) + eps))
+            # it now has unit mean.
+            G_prime_max = group["cov_max"][3]
+            G_prime = 1. / (1. / G_prime + 1. / G_prime_max)  # apply max
+            G_prime_pow = group["cov_pow"][3]
+            if G_prime_pow != 1.0:
+                G_prime = G_prime ** G_prime_pow

         G_prime_rms = G_prime.sqrt()
         G_prime_scale = G_prime_rms.unsqueeze(-1) * G_prime_rms.unsqueeze(-2)
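The soft cap above hardly changes moderate eigenvalues but bounds extreme ones.
With the default cov_max[3] = 400.0 (eigenvalues here invented for illustration):

    import torch

    G_prime = torch.tensor([1.0, 100.0, 1e6])
    print(1. / (1. / G_prime + 1. / 400.0))   # tensor([  0.9975,  80.0000, 399.8401])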
@@ -1073,17 +1083,17 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         P_gnorm = P_prime * G_prime_scale

         # Apply another round of smoothing, "relative to G".
         P_gnorm = self._smooth_cov(P_gnorm,
-                                   group["param_cov_min"][1],
-                                   group["param_cov_max"][1],
-                                   group["param_pow"][1])
+                                   group["cov_min"][1],
+                                   group["cov_max"][1],
+                                   group["cov_pow"][1])

         # Undo the scaling relative to G, so we have the stage-2-smoothed version of P_prime.
         P_prime = P_gnorm / G_prime_scale

         # Apply a 3rd round of smoothing, in the canonical basis.
         P_prime = self._smooth_cov(P_prime,
-                                   group["param_cov_min"][2],
-                                   group["param_cov_max"][2],
-                                   group["param_pow"][2])
+                                   group["cov_min"][2],
+                                   group["cov_max"][2],
+                                   group["cov_pow"][2])

         return P_prime

     def _smooth_cov(self,
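The overall pattern of the hunk above is: scale P_prime by G_prime_scale,
smooth in that scaled space, then divide the scaling back out. A minimal
sketch, with smooth_cov standing in for self._smooth_cov (whose body this
diff does not show) as an identity placeholder:

    import torch

    def smooth_cov(P, cov_min, cov_max, cov_pow):
        return P  # placeholder; the real method applies the min/max/pow smoothing

    P_prime = torch.eye(3).unsqueeze(0)              # (batch, size, size)
    g_rms = torch.tensor([[0.5, 1.0, 1.5]]).sqrt()   # per-eigenvalue rms of G_prime
    G_prime_scale = g_rms.unsqueeze(-1) * g_rms.unsqueeze(-2)
    P_gnorm = smooth_cov(P_prime * G_prime_scale, 0.01, 40.0, 1.0)
    P_prime = P_gnorm / G_prime_scale                # stage-2-smoothed version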
@@ -1378,66 +1388,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         while Q.ndim < M.ndim:
             Q = Q.unsqueeze(2)
         # now Q has shape (batch_size, num_blocks, 1, 1, block_size, block_size)
-        M = torch.matmul(M, Q)  # (batch_size, num_blocks, x, y, z, block_size)
+        if M.ndim < Q.ndim:  # special case where M shape is e.g. (batch_size, num_blocks, x)
+            M = torch.matmul(M.unsqueeze(-2), Q).squeeze(-2)
+        else:
+            M = torch.matmul(M, Q)  # (batch_size, num_blocks, x, y, z, block_size)
         M = _move_dim(M, 1, -2)  # (batch_size, x, y, z, num_blocks, block_size)
         M = M.reshape(*M.shape[:-2], size)  # (batch_size, x, y, z, size)
         M = M.transpose(-1, dim)  # (batch_size, x, size, y, z)
         return M
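A self-contained check of the new special case (shapes chosen arbitrarily):
when M has only 3 dims, its last dim is treated as a 1 x block_size row vector
so it can still be multiplied by the (block_size, block_size) blocks of Q.

    import torch

    batch_size, num_blocks, block_size = 2, 3, 4
    M = torch.randn(batch_size, num_blocks, block_size)
    Q = torch.randn(batch_size, num_blocks, block_size, block_size)
    out = torch.matmul(M.unsqueeze(-2), Q).squeeze(-2)
    print(out.shape)  # torch.Size([2, 3, 4])
    torch.testing.assert_close(out[0, 0], M[0, 0] @ Q[0, 0])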
-    def _smooth_param_rms(self,
-                          group: dict,
-                          rms: Tensor,
-                          rank: int) -> Tensor:
-        """
-        Smooths, and normalizes to mean=1, a tensor of shape something like
-        (batch_size, 1, size, 1, 1), where for the nontrivial dim `size` it contains
-        diagonalized parameter rms values; this is for one dimension of a
-        multi-dimensional tensor.  This will be used to construct a learning-rate matrix.
-
-        Args:
-           group: dict for configuration values
-           rms: A tensor of shape (batch_size, [1,1,..], size, [1,1,1,..]), representing
-               (for each batch element) a list of root-mean-square values, one per
-               dimension of the space of size `size`.
-           rank: the assumed rank of the covariance matrix from which rms was derived,
-               used to decide how much smoothing to apply.  This is actually the
-               minimal rank, if the parameter matrix were to stay fixed during training,
-               but it is still relevant to how robust the parameter covariance estimate is.
-        Returns:
-            a Tensor with the same shape as `rms`, but with some smoothing applied so
-            there are no values too close to zero.
-        """
-        smooth0 = group["param_rms_smooth0"]
-        smooth1 = group["param_rms_smooth1"]
-        param_cov_max = group["param_cov_max"]
-        param_pow = group["param_pow"]
-        eps = group["eps"]
-
-        batch_size = rms.shape[0]
-        size = rms.numel() // batch_size
-
-        rms = rms ** param_pow
-        # We want an expression of the form: smooth = alpha * size / (beta*rank + size).
-        # From rank==0, we get smooth0 = alpha * size/size, so alpha = smooth0.
-        # From rank==size, we get smooth1 = alpha * size / (beta*size + size) = alpha/(1+beta),
-        # so smooth1 == smooth0 / (1+beta), i.e. (1+beta) = smooth0/smooth1, so beta = smooth0/smooth1 - 1.
-        smooth = smooth0 * size / ((smooth0/smooth1 - 1) * rank + size)
-
-        mean = _mean(rms, exclude_dims=[0], keepdim=True)
-        rms += eps + smooth * mean
-        new_mean = (eps + (smooth + 1) * mean)  # mean of the modified rms.
-        ans = rms / new_mean
-        # Apply a `soft min` with param_cov_max via the formula
-        # softmin(x,y) = 1/(1/x + 1/y).
-        ans = 1. / (1. / ans + 1. / param_cov_max)
-        # and renormalize to mean=1.
-        ans /= _mean(ans, exclude_dims=[0], keepdim=True)
-        return ans
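The deleted helper's smoothing amount interpolates between param_rms_smooth0 at
rank 0 and param_rms_smooth1 at rank == size, exactly as the removed comments
derive. A quick standalone check of the endpoints, using the constructor
defaults (this is a re-statement of the deleted math, not new optimizer code):

    def smooth_amount(rank: int, size: int,
                      smooth0: float = 0.4, smooth1: float = 0.2) -> float:
        return smooth0 * size / ((smooth0 / smooth1 - 1) * rank + size)

    size = 100
    print(smooth_amount(0, size))     # 0.4 == param_rms_smooth0
    print(smooth_amount(size, size))  # 0.2 == param_rms_smooth1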
 def _move_dim(x: Tensor, orig_dim: int, new_dim: int) -> Tensor:
     """
@@ -2197,9 +2156,8 @@ def _test_eve_cain():
     for iter in [3, 2]:
         fix_random_seed(42)
         Linear = torch.nn.Linear if iter == 0 else ScaledLinear
-        # TODO: find out why this is not converging...
-        hidden_dim = 512
+        hidden_dim = 768
         m = torch.nn.Sequential(Linear(E, hidden_dim),
                                 torch.nn.PReLU(),
                                 Linear(hidden_dim, hidden_dim),