Cleanup and refactoring

Daniel Povey 2022-07-24 05:48:38 +08:00
parent 8a9bbb93bc
commit 6290fcb535


@@ -109,7 +109,7 @@ class PrAdam(BatchedOptimizer):
is the scaling factor on the learning rate of p_scale.
param_pow: Power on the parameter covariance matrix, 1.0 means learn proportional
to parameter rms (1.0 will be too much; it should be between 0 and 1).
This is one of the most important tunable factors, along with max_lr_factor.
This is one of the most important tunable factors, along with param_cov_max.
param_rms_smooth0: Limiting value of smoothing proportion for parameter matrix, as
assumed rank of param covariance [==product of sizes on the other
tensor dims] approaches 0.
@@ -117,24 +117,41 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
param covariance equals the dimension of the covariance matrix.
param_rms_smooth{0,1} determine the smoothing proportions for other
conditions.
max_lr_factor: How much faster we allow any direction in parameter space to learn
than the mean... this is a relatively important thing to tune,
along with param_pow.
eps: An epsilon to prevent division by zero
param_cov_min: [IMPORTANT] A 3-tuple of minimums of the diagonal values of the parameter
covariance, normalized in 3 different ways: relative to its own
diagonal, scaled by the grad covariance, and in the canonical basis.
Together with param_cov_max, this defines how "aggressive" we allow our
update to be.
param_cov_max: [IMPORTANT] A 3-tuple of maximums of the diagonal values of the parameter
covariance, normalized in 3 different ways: relative to its own
diagonal, scaled by the grad covariance, and in the canonical basis.
param_pow: This was mainly added for development and experimentation purposes;
it allows you to smooth the parameter covariance by taking it to
a power, but if this is not 1.0 it will cause a speed penalty because
it requires SVD.
eps: A general-purpose epsilon to prevent division by zero
param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
learning the scale on the parameters (we'll keep it >= this size)
learning the scale on the parameters (we'll constrain the rms of each non-scalar
parameter tensor to be >= this value)
param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
learning the scale on the parameters (we'll keep it <= this size)
scalar_max: Maximum absolute value for scalar parameters
learning the scale on the parameters (we'll constrain the rms of each non-scalar
parameter tensor to be <= this value)
scalar_max: Maximum absolute value for scalar parameters (applicable if your
model has any parameters with numel() == 1)
size_update_period: The periodicity, in steps, with which we update the size (scale)
of the parameter tensor. This is provided to save a little time.
lr_update_period: Determines the periodicity, in steps, with which we update the
learning-rate matrices. The first number is the periodicity at
of the parameter tensor. This is provided to save a little time
in the update.
lr_update_period: [IMPORTANT] A 2-tuple of ints that determines the periodicity, in steps, with
which we update the learning-rate matrices. The first number is the periodicity at
the start of training, the second number is the periodicity
later in training. One step of updating the learning rate matrices
can take as long as 50 or more minibatches, because SVD on GPU is slow.
** This is important for the speed/optimization tradeoff. **
max_block_size: The maximum block size in block-diagonal co-ordinate transformations.
later in training, and we gradually increase from one to the other.
The reason for such a complicated schedule is that updating the learning
rate matrices is very slow, principally because it requires SVD, and SVD
seems to have quite slow implementations.
max_block_size: [IMPORTANT] The maximum block size in block-diagonal co-ordinate
transformations. You can probably set this to 512 or 1024. Larger
values will require more MEMORY and may be a bit slower, but should
lead to better optimization performance.
"""
def __init__(
self,
@@ -142,8 +159,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
lr=3e-02,
betas=(0.9, 0.98),
size_lr_scale=0.1,
min_lr_factor=(0.05, 0.01, 0.01),
max_lr_factor=(10.0, 40.0, 10.0),
param_cov_min=(0.05, 0.01, 0.01),
param_cov_max=(10.0, 40.0, 10.0),
param_pow=(1.0, 1.0, 1.0),
param_rms_smooth0=0.4,
param_rms_smooth1=0.2,
@@ -165,8 +182,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
defaults = dict(
lr=lr,
size_lr_scale=size_lr_scale,
min_lr_factor=min_lr_factor,
max_lr_factor=max_lr_factor,
param_cov_min=param_cov_min,
param_cov_max=param_cov_max,
param_pow=param_pow,
param_rms_smooth0=param_rms_smooth0,
param_rms_smooth1=param_rms_smooth1,
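For orientation, a minimal usage sketch of the renamed options follows. It is not part of the diff; the concrete values are just the defaults shown above or hypothetical, and it assumes PrAdam exposes the standard torch.optim.Optimizer step interface.

import torch

# Hypothetical usage sketch, not repository code.
model = torch.nn.Linear(256, 256)
optimizer = PrAdam(
    model.parameters(),
    lr=3e-02,
    param_cov_min=(0.05, 0.01, 0.01),  # minimums on the param-covariance diagonal, one per normalization
    param_cov_max=(10.0, 40.0, 10.0),  # maximums; together these bound how aggressive the update can be
    max_block_size=512,                # documented above; larger values need more memory
)
loss = model(torch.randn(8, 256)).pow(2).mean()
loss.backward()
optimizer.step()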
@@ -649,12 +666,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# so we need to transpose Q as we convert M to the diagonalized co-ordinate.
M = torch.matmul(M, Q.transpose(-2, -1)) # (batch_size, num_blocks, x, z, y, block_size)
M = _move_dim(M, 1, -2) # (batch_size, x, z, y, num_blocks, block_size)
M = M.reshape(*M.shape[:-2], size) # # (batch_size, x, z, y, size)
M = M.reshape(*M.shape[:-2], size) # (batch_size, x, z, y, size)
cur_p = M.transpose(dim, -1) # (batch_size, x, size, y, z)
# cur_param_var is a diagonal parameter variance over dimension `dim`,
# of the current "slightly-whitened" parameter; it
# will have shape something like [1, size, 1]; or [batch_size, 1, size, 1].
# will have shape [batch_size, 1, size, 1].
cur_param_var = _mean(cur_p**2,
exclude_dims=[0,dim],
keepdim=True) # (batch_size, 1, size, 1, 1) if dim==2
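The shape bookkeeping above can be checked in isolation. The following is a shape-only illustration with made-up dimensions, assuming size == num_blocks * block_size and dim == 2:

import torch

# Toy dimensions only; names mirror the shape comments above.
batch_size, x, z, y, num_blocks, block_size = 2, 3, 4, 5, 6, 8
size = num_blocks * block_size
M = torch.randn(batch_size, x, z, y, num_blocks, block_size)
M = M.reshape(*M.shape[:-2], size)               # (batch_size, x, z, y, size)
cur_p = M.transpose(2, -1)                       # (batch_size, x, size, y, z)
# plain-torch stand-in for _mean(cur_p**2, exclude_dims=[0, dim], keepdim=True):
cur_param_var = cur_p.pow(2).mean(dim=(1, 3, 4), keepdim=True)
assert cur_param_var.shape == (batch_size, 1, size, 1, 1)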
@@ -848,6 +865,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# It is not smoothed yet.
P_prime = torch.matmul(U_g.transpose(2, 3), torch.matmul(param_cov, U_g))
P_prime_unsmoothed = P_prime
P_prime = self._smooth_param_cov(group, p_shape, P_prime, G_prime)
# C will satisfy: P_prime == torch.matmul(C, C.transpose(2, 3))
@@ -900,7 +918,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
U_z, S, _ = _svd(Z)
if True:
skip = 10 if S.shape[-1] > 40 else 1
logging.info(f"Eigs of Z are: {S[0,0,::skip]}")
logging.info(f"dim={dim}, G_prime is {G_prime[0,0,::skip]}, Eigs of Z are: {S[0,0,::skip]}")
# state[f"Q_{dim}"] is indexed: [batch_idx, block_idx, diagonalized_coordinate, canonical_coordinate].
# so we need to transpose U_z as U_z is indexed
@@ -917,8 +935,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
this_P_proj = _diag(torch.matmul(U_prod, torch.matmul(P_prime, U_prod.transpose(2, 3))))
P_proj[dim] = this_P_proj.clone().reshape(batch_size, size)
if True:
this_P_proj_unsmoothed = _diag(torch.matmul(U_prod, torch.matmul(P_prime_unsmoothed,
U_prod.transpose(2, 3))))
this_P_proj_unsmoothed = this_P_proj_unsmoothed.clone().reshape(batch_size, size)
skip = 10 if P_proj[dim].shape[-1] > 40 else 1
logging.info(f"Eigs of P_proj are: {P_proj[dim][0,::skip]}")
logging.info(f"dim={dim}, diag of P_proj is: {P_proj[dim][0,::skip]}, diag of unsmoothed P_proj is: {this_P_proj_unsmoothed[0,::skip]}")
return P_proj
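A side note on the diagonal extraction above: the diagonal of U P U^T can be obtained without forming the full triple product. This is only an illustration of the identity, not a claim about how _diag is implemented here.

import torch

U = torch.randn(4, 4)
P = torch.randn(4, 4)
P = P @ P.t()                           # symmetric PSD, like a covariance
full = torch.diag(U @ P @ U.t())        # reference: diagonal of the full product
cheap = (U @ P * U).sum(dim=-1)         # diag(U P U^T)[i] = u_i^T P u_i
assert torch.allclose(full, cheap, atol=1e-5)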
@@ -989,15 +1010,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
_diag(P_norm).add_(smooth)
P_norm = self._smooth_cov(P_norm,
group["min_lr_factor"][0],
group["max_lr_factor"][0],
group["param_cov_min"][0],
group["param_cov_max"][0],
group["param_pow"][0])
# Remove the diagonal preconditioning on P_norm, giving us stage-1-smoothed
# version of P_prime.
P_prime = P_norm * P_prime_scale
# Make sure G_prime has unit mean and no eigenvalue is super small. Note, G_prime
# is already diagonal.
# is already diagonalized; the variable G_prime is just the tensor of eigenvalues.
G_prime_mean = _mean(G_prime, exclude_dims=[0], keepdim=True)
G_prime_smooth = 0.001
# make sure G_prime has no zero eigs, and is unit mean.
@@ -1012,16 +1033,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
P_gnorm = P_prime * G_prime_scale
# Apply another round of smoothing "relative to G"
P_gnorm = self._smooth_cov(P_gnorm,
group["min_lr_factor"][1],
group["max_lr_factor"][1],
group["param_cov_min"][1],
group["param_cov_max"][1],
group["param_pow"][1])
# Undo the scaling relative to G, so we have stage-2-smoothed version of P_prime.
P_prime = P_gnorm / G_prime_scale
# Apply a 3rd round of smoothing in the canonical basis.
P_prime = self._smooth_cov(P_prime,
group["min_lr_factor"][2],
group["max_lr_factor"][2],
group["param_cov_min"][2],
group["param_cov_max"][2],
group["param_pow"][2])
return P_prime
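The first of the three stages above relies on a diagonal preconditioning that is later undone by multiplying with P_prime_scale. Below is a small, self-contained illustration of that normalize-then-undo pattern; it is hypothetical helper code, not the repository's _smooth_cov, whose body is not shown in this hunk.

import torch

P = torch.randn(4, 4)
P = P @ P.t() + torch.eye(4)                                # a positive-definite covariance
d = P.diag()
P_prime_scale = (d.unsqueeze(-1) * d.unsqueeze(-2)).sqrt()  # sqrt(d_i * d_j)
P_norm = P / P_prime_scale                                  # unit diagonal, like a correlation matrix
assert torch.allclose(P_norm.diag(), torch.ones(4), atol=1e-5)
P_back = P_norm * P_prime_scale                             # undoing the preconditioning recovers P
assert torch.allclose(P_back, P, atol=1e-4)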
@@ -1349,7 +1370,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
"""
smooth0 = group["param_rms_smooth0"]
smooth1 = group["param_rms_smooth1"]
max_lr_factor = group["max_lr_factor"]
param_cov_max = group["param_cov_max"]
param_pow = group["param_pow"]
eps = group["eps"]
batch_size = rms.shape[0]
@@ -1369,9 +1390,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
ans = rms / new_mean
# Apply a `soft min` of max_lr_factor via the formula
# Apply a `soft min` of param_cov_max via the formula
# softmin(x,y) = 1/(1/x + 1/y).
ans = 1. / (1. / ans + 1. / max_lr_factor)
ans = 1. / (1. / ans + 1. / param_cov_max)
# and renormalize to mean=1.
ans /= _mean(ans, exclude_dims=[0], keepdim=True)
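As a quick numeric check of the soft min above: softmin(x, y) = 1/(1/x + 1/y) never exceeds min(x, y) and approaches the smaller argument when the two are far apart. With a bound of 40.0 (one of the default param_cov_max entries), for example:

# illustration only
x = 50.0
print(1.0 / (1.0 / x + 1.0 / 40.0))   # ~22.2: values above the bound are pulled well below it
y = 2.0
print(1.0 / (1.0 / y + 1.0 / 40.0))   # ~1.9: values far below the bound are nearly unchanged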