Cleanup and refactoring

Daniel Povey 2022-07-24 05:48:38 +08:00
parent 8a9bbb93bc
commit 6290fcb535


@@ -109,7 +109,7 @@ class PrAdam(BatchedOptimizer):
                    is the scaling factor on the learning rate of p_scale.
         param_pow: Power on the parameter covariance matrix, 1.0 means learn proportional
                    to parameter rms (1.0 will be too much, should be between 0 and 1.)
-                   This is one of the most important tunable factors, along with max_lr_factor.
+                   This is one of the most important tunable factors, along with param_cov_max.
 param_rms_smooth0: Limiting value of smoothing proportion for parameter matrix, as
                    assumed rank of param covariance [==product of sizes on the other
                    tensor dims] approaches 0.
@@ -117,24 +117,41 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                    param covariance equals the dimension of the covaraince matrix.
                    param_rms_smooth{0,1} determine the smoothing proportions for other
                    conditions.
-    max_lr_factor: How much faster we allow any direction in parameter space to learn faster
-                   than the mean... this is a relatively important thing to tune,
-                   along with param_pow.
-              eps: An epsilon to prevent division by zero
+    param_cov_min: [IMPORTANT] A 3-tuple of minimums of the diagonal values of the parameter
+                   covariance, normalized in 3 different ways: relative to its own
+                   diagonal, scaled by the grad covariance, and in the canonical basis.
+                   With param_cov_max, defines how "aggressive" we allow our update to
+                   be.
+    param_cov_max: [IMPORTANT] A 3-tuple of maximums of the diagonal values of the parameter
+                   covariance, normalized in 3 different ways: relative to its own
+                   diagonal, scaled by the grad covariance, and in the canonical basis.
+        param_pow: This was mainly added for development and experimentation purposes;
+                   it allows you to smooth the parameter covariance by taking it to
+                   a power, but if this is not 1.0 it will cause a speed penalty because
+                   it requires SVD.
+              eps: A general-purpose epsilon to prevent division by zero
     param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
-                   learning the scale on the parameters (we'll keep it >= this size)
+                   learning the scale on the parameters (we'll constrain the rms of each non-scalar
+                   parameter tensor to be >= this value)
     param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
-                   learning the scale on the parameters (we'll keep it <= this size)
-       scalar_max: Maximum absolute value for scalar parameters
+                   learning the scale on the parameters (we'll constrain the rms of each non-scalar
+                   parameter tensor to be <= this value)
+       scalar_max: Maximum absolute value for scalar parameters (applicable if your
+                   model has any parameters with numel() == 1)
 size_update_period: The periodicity, in steps, with which we update the size (scale)
-                   of the parameter tensor. This is provided to save a little time.
- lr_update_period: Determines the periodicity, in steps, with which we update the
-                   learning-rate matrices. The first number is the periodicity at
+                   of the parameter tensor. This is provided to save a little time
+                   in the update.
+ lr_update_period: [IMPORTANT]: A 2-tuple of ints that Determines the periodicity, in steps, with
+                   which we update the learning-rate matrices. The first number is the periodicity at
                    the start of training, the second number is the periodicity
-                   later in training. One step of updating the learning rate matrices
-                   can take as long as over 50 minibatches, because SVD on GPU is slow.
-                   ** This is important for the speed/optimizaton tradeoff. **
-   max_block_size: The maximum block size in block-diagonal co-ordinate transformations.
+                   later in training, and we gradually increase from one to the other.
+                   The reason for such a complicated schedule is that updating the learning
+                   rate matrices is very slow, principally because it requires SVD, and SVD
+                   seems to have quite slow implementations.
+   max_block_size: [IMPORTANT] The maximum block size in block-diagonal co-ordinate
+                   transformations. You can probably set this to 512 or 1024. Larger
+                   values will require more MEMORY and may be a bit slower, but should
+                   lead to better optimization performance.
     """
     def __init__(
         self,
@@ -142,8 +159,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             lr=3e-02,
             betas=(0.9, 0.98),
             size_lr_scale=0.1,
-            min_lr_factor=(0.05, 0.01, 0.01),
-            max_lr_factor=(10.0, 40.0, 10.0),
+            param_cov_min=(0.05, 0.01, 0.01),
+            param_cov_max=(10.0, 40.0, 10.0),
             param_pow=(1.0, 1.0, 1.0),
             param_rms_smooth0=0.4,
             param_rms_smooth1=0.2,
@@ -165,8 +182,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         defaults = dict(
             lr=lr,
             size_lr_scale=size_lr_scale,
-            min_lr_factor=min_lr_factor,
-            max_lr_factor=max_lr_factor,
+            param_cov_min=param_cov_min,
+            param_cov_max=param_cov_max,
             param_pow=param_pow,
             param_rms_smooth0=param_rms_smooth0,
             param_rms_smooth1=param_rms_smooth1,
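To make the renamed options concrete, here is a minimal usage sketch. It assumes PrAdam (imported from the file this diff touches) takes an iterable of parameters first, like other torch.optim optimizers; the toy model and the lr_update_period value are made up for illustration, while the other values mirror the defaults and docstring advice shown above.

    import torch

    model = torch.nn.Linear(256, 256)          # toy model, for illustration only

    optimizer = PrAdam(
        model.parameters(),
        lr=3e-02,
        betas=(0.9, 0.98),
        param_cov_min=(0.05, 0.01, 0.01),      # was min_lr_factor
        param_cov_max=(10.0, 40.0, 10.0),      # was max_lr_factor
        lr_update_period=(200, 1000),          # hypothetical 2-tuple (start, later in training)
        max_block_size=512,                    # docstring suggests 512 or 1024
    )

    for _ in range(10):                        # toy training loop
        loss = model(torch.randn(32, 256)).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()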
@@ -649,12 +666,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # so we need to transpose Q as we convert M to the diagonalized co-ordinate.
         M = torch.matmul(M, Q.transpose(-2, -1))  # (batch_size, num_blocks, x, z, y, block_size)
         M = _move_dim(M, 1, -2)  # (batch_size, x, z, y, num_blocks, block_size)
-        M = M.reshape(*M.shape[:-2], size)  # # (batch_size, x, z, y, size)
+        M = M.reshape(*M.shape[:-2], size)  # (batch_size, x, z, y, size)
         cur_p = M.transpose(dim, -1)  # (batch_size, x, size, y, z)
         # cur_param_var is a diagonal parameter variance over dimension `dim`,
         # of the current "slightly-whitened" parameter; it
-        # will have shape something like [1, size, 1]; or [batch_size, 1, size, 1].
+        # will have shape [batch_size, 1, size, 1].
         cur_param_var = _mean(cur_p**2,
                               exclude_dims=[0,dim],
                               keepdim=True)  # (batch_size, 1, size, 1, 1) if dim==2
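For readers following the shape comments above, here is a self-contained sketch of the block-diagonal co-ordinate change being applied to a single dimension: a dimension of length size is split into num_blocks blocks, each block is multiplied by the transpose of its own Q, and the blocks are merged back. Names and shapes are illustrative, not the optimizer's internal ones.

    import torch

    batch_size, size, num_blocks = 2, 8, 4
    block_size = size // num_blocks

    # One matrix per block, indexed [batch, block, diagonalized_coord, canonical_coord].
    Q = torch.randn(batch_size, num_blocks, block_size, block_size)
    # A tensor whose last dimension is the one being transformed.
    M = torch.randn(batch_size, size)

    M_blocks = M.reshape(batch_size, num_blocks, block_size)      # split into blocks
    M_diag = torch.matmul(M_blocks.unsqueeze(-2),                 # treat each block as a row vector
                          Q.transpose(-2, -1)).squeeze(-2)        # apply Q^T per block
    M_out = M_diag.reshape(batch_size, size)                      # merge blocks again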
@@ -848,6 +865,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # It is not smoothed yet.
         P_prime = torch.matmul(U_g.transpose(2, 3), torch.matmul(param_cov, U_g))
+        P_prime_unsmoothed = P_prime
         P_prime = self._smooth_param_cov(group, p_shape, P_prime, G_prime)
         # C will satisfy: P_prime == torch.matmul(C, C.transpose(2, 3))
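The only functional change in this hunk is keeping an unsmoothed copy of P_prime so the logging added further down can compare smoothed and unsmoothed diagonals. For reference, the line above it is simply a change of basis of the parameter covariance into the eigenbasis of the gradient covariance, P' = U_g^T P U_g; a standalone illustration with arbitrary sizes and a stand-in orthogonal basis:

    import torch

    param_cov = torch.randn(4, 4)
    param_cov = param_cov @ param_cov.T          # a symmetric positive semi-definite matrix
    U_g = torch.linalg.qr(torch.randn(4, 4)).Q   # stand-in for the grad-covariance eigenvectors

    P_prime = U_g.T @ param_cov @ U_g            # parameter covariance expressed in that basis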
@@ -900,7 +918,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         U_z, S, _ = _svd(Z)
         if True:
             skip = 10 if S.shape[-1] > 40 else 1
-            logging.info(f"Eigs of Z are: {S[0,0,::skip]}")
+            logging.info(f"dim={dim}, G_prime is {G_prime[0,0,::skip]}, Eigs of Z are: {S[0,0,::skip]}")
         # state[f"Q_{dim}"] is indexed: [batch_idx, block_idx, diagonalized_coordinate, canonical_coordinate].
         # so we need to transpose U_z as U_z is indexed
@@ -917,8 +935,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         this_P_proj = _diag(torch.matmul(U_prod, torch.matmul(P_prime, U_prod.transpose(2, 3))))
         P_proj[dim] = this_P_proj.clone().reshape(batch_size, size)
         if True:
+            this_P_proj_unsmoothed = _diag(torch.matmul(U_prod, torch.matmul(P_prime_unsmoothed,
+                                                                             U_prod.transpose(2, 3))))
+            this_P_proj_unsmoothed = this_P_proj_unsmoothed.clone().reshape(batch_size, size)
             skip = 10 if P_proj[dim].shape[-1] > 40 else 1
-            logging.info(f"Eigs of P_proj are: {P_proj[dim][0,::skip]}")
+            logging.info(f"dim={dim}, diag of P_proj is: {P_proj[dim][0,::skip]}, diag of unsmoothed P_proj is: {this_P_proj_unsmoothed[0,::skip]}")
         return P_proj
@@ -989,15 +1010,15 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         _diag(P_norm).add_(smooth)
         P_norm = self._smooth_cov(P_norm,
-                                  group["min_lr_factor"][0],
-                                  group["max_lr_factor"][0],
+                                  group["param_cov_min"][0],
+                                  group["param_cov_max"][0],
                                   group["param_pow"][0])
         # Remove the diagonal preconditioning on P_norm, giving us stage-1-smoothed
         # version of P_prime.
         P_prime = P_norm * P_prime_scale
         # Make sure G_prime has unit mean and no eigenvalue is super small. Note, G_prime
-        # is already diagonal.
+        # is already diagonalized, the variable G_prime is just the tensor of eigenvalues.
         G_prime_mean = _mean(G_prime, exclude_dims=[0], keepdim=True)
         G_prime_smooth = 0.001
         # make sure G_prime has no zero eigs, and is unit mean.
@@ -1012,16 +1033,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         P_gnorm = P_prime * G_prime_scale
         # Apply another round of smoothing "relative to G"
         P_gnorm = self._smooth_cov(P_gnorm,
-                                   group["min_lr_factor"][1],
-                                   group["max_lr_factor"][1],
+                                   group["param_cov_min"][1],
+                                   group["param_cov_max"][1],
                                    group["param_pow"][1])
         # Undo the scaling relative to G, so we have stage-2-smoothed version of P_prime.
         P_prime = P_gnorm / G_prime_scale
         # Apply a 3rd round of smoothing in the canonical basis.
         P_prime = self._smooth_cov(P_prime,
-                                   group["min_lr_factor"][2],
-                                   group["max_lr_factor"][2],
+                                   group["param_cov_min"][2],
+                                   group["param_cov_max"][2],
                                    group["param_pow"][2])
         return P_prime
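The three entries of param_cov_min, param_cov_max and param_pow therefore correspond to the three consecutive smoothing passes above: relative to the covariance's own diagonal, relative to the gradient covariance, and in the canonical basis. _smooth_cov itself is not part of this diff; as a rough, hypothetical illustration of what limiting a normalized diagonal between a minimum and a maximum can look like (using the same softmin trick as the final hunk below):

    import torch

    def soft_limit(diag: torch.Tensor, cov_min: float, cov_max: float) -> torch.Tensor:
        # Illustration only: keep a positive diagonal roughly within [cov_min, cov_max]
        # after normalizing it to mean 1; this is NOT the optimizer's _smooth_cov.
        diag = diag / diag.mean()
        diag = diag + cov_min                        # soft floor: x + y acts as a "soft max"
        diag = 1.0 / (1.0 / diag + 1.0 / cov_max)    # soft cap: softmin(x, y) = 1/(1/x + 1/y)
        return diag / diag.mean()

    print(soft_limit(torch.tensor([1e-3, 0.5, 1.0, 200.0]), cov_min=0.05, cov_max=10.0))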
@@ -1349,7 +1370,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         """
         smooth0 = group["param_rms_smooth0"]
         smooth1 = group["param_rms_smooth1"]
-        max_lr_factor = group["max_lr_factor"]
+        param_cov_max = group["param_cov_max"]
         param_pow = group["param_pow"]
         eps = group["eps"]
         batch_size = rms.shape[0]
@@ -1369,9 +1390,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         ans = rms / new_mean
-        # Apply a `soft min` of max_lr_factor via the formula
+        # Apply a `soft min` of param_cov_max via the formula
         # softmin(x,y) = 1/(1/x + 1/y).
-        ans = 1. / (1. / ans + 1. / max_lr_factor)
+        ans = 1. / (1. / ans + 1. / param_cov_max)
         # and renormalize to mean=1.
         ans /= _mean(ans, exclude_dims=[0], keepdim=True)
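A quick numeric check of the softmin used here, with param_cov_max = 10.0 (the input values are arbitrary): small inputs pass through almost unchanged, while large inputs are capped just below param_cov_max.

    param_cov_max = 10.0
    for x in (0.1, 1.0, 100.0):
        print(x, 1.0 / (1.0 / x + 1.0 / param_cov_max))
    # 0.1   -> ~0.099   (nearly unchanged)
    # 1.0   -> ~0.909
    # 100.0 -> ~9.09    (capped near param_cov_max)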