Fix bug in smooth_cov, for power==1.0

Author: Daniel Povey, 2022-07-23 09:06:03 +08:00
Parent: cc388675a9
Commit: b47433b77a


@@ -144,6 +144,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             size_lr_scale=0.1,
             min_lr_factor=(0.01, 0.01, 0.01),
             max_lr_factor=(10.0, 10.0, 10.0),
+            #param_pow=(0.99999, 0.99999, 0.99999),
+            param_pow=(1.0, 1.0, 1.0),
             param_rms_smooth0=0.75,
             param_rms_smooth1=0.25,
             eps=1.0e-08,
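
The new param_pow hyperparameter is a 3-tuple, one entry per smoothing stage, parallel to min_lr_factor and max_lr_factor; 1.0 selects the cheaper inverse-based path in _smooth_cov below, while values just below 1.0 exercise the SVD path. A minimal, hypothetical construction sketch (PrAdam, m.parameters() and lr=0.03 are taken from the test code at the bottom of this diff; the tuples just echo the defaults above):

    # Sketch only: one (min, max, pow) entry per smoothing stage;
    # indices [0], [1], [2] select the stage, as in the call sites below.
    optim = PrAdam(m.parameters(), lr=0.03,
                   min_lr_factor=(0.01, 0.01, 0.01),
                   max_lr_factor=(10.0, 10.0, 10.0),
                   param_pow=(1.0, 1.0, 1.0))  # 1.0 -> inverse path; <1.0 -> SVD path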
@@ -163,6 +165,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             size_lr_scale=size_lr_scale,
             min_lr_factor=min_lr_factor,
             max_lr_factor=max_lr_factor,
+            param_pow=param_pow,
             param_rms_smooth0=param_rms_smooth0,
             param_rms_smooth1=param_rms_smooth1,
             betas=betas,
@@ -678,7 +681,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # cur_scales for the other dims.
         cur_scales = [None] * ndim
-        debug = (random.random() < 0.001)
+        debug = (random.random() < 0.1)
         for i in range(4):  # for 4 iterations (this is quite arbitrary)
             for dim in range(1, ndim):
                 size = p.shape[dim]
@@ -949,7 +952,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         P_norm = self._smooth_cov(P_norm,
                                   group["min_lr_factor"][0],
-                                  group["max_lr_factor"][0])
+                                  group["max_lr_factor"][0],
+                                  group["param_pow"][0])
         # Remove the diagonal preconditioning on P_norm, giving us stage-1-smoothed
         # version of P_prime.
         P_prime = P_norm * P_prime_scale
@@ -969,14 +973,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # Apply another round of smoothing "relative to G"
         P_gnorm = self._smooth_cov(P_gnorm,
                                    group["min_lr_factor"][1],
-                                   group["max_lr_factor"][1])
+                                   group["max_lr_factor"][1],
+                                   group["param_pow"][1])
         # Undo the scaling relative to G, so we have stage-2-smoothed version of P_prime.
         P_prime = P_gnorm * G_prime_scale
         # Apply a 3rd round of smoothing
         P_prime = self._smooth_cov(P_prime,
                                    group["min_lr_factor"][2],
-                                   group["max_lr_factor"][2])
+                                   group["max_lr_factor"][2],
+                                   group["param_pow"][2])
         return P_prime

     def _smooth_cov(self,
@@ -1007,9 +1013,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         X: the batch of symmetric positive definite tensors we are smoothing;
            of shape (batch_size, num_blocks, block_size, block_size)
         """
-        eps = 1.0e-10
         if power != 1.0:
             U, S, _ = _svd(X)
+            eps = 1.0e-10
             S_mean = _mean(S, exclude_dims=[0], keepdim=True)
             S = S + min_eig * S_mean + eps
             S_mean = S_mean * (1 + min_eig) + eps
@@ -1019,17 +1025,18 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             S = S / _mean(S, exclude_dims=[0], keepdim=True)
             return torch.matmul(U * S.unsqueeze(-2), U.transpose(2, 3))
         else:
-            X = X.clone()  # may be
+            X = X.clone()
             diag = _diag(X)  # Aliased with X
             mean_eig = _mean(diag, exclude_dims=[0], keepdim=True)
+            eps = 1.0e-10  # prevent division by zero
             diag += (mean_eig * min_eig + eps)
             cur_diag_mean = mean_eig * (1 + min_eig) + eps
             # The following 2 statements will be equivalent to:
             #   L /= L.mean()
             #   L = 1 / (1/L + 1/max_eig)   # soft-min between L and max_eig
             # if L is the eigenvalues
-            X = (X.inverse() + 1/(max_eig * cur_diag_mean.unsqueeze(-1))).inverse()
+            X_inv = X.inverse()
+            _diag(X_inv).add_(1. / (max_eig * cur_diag_mean))
+            X = X_inv.inverse()
             X /= _mean(_diag(X), exclude_dims=[0], keepdim=True).unsqueeze(-1)
             return X
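
This last hunk is the actual bug fix: the old line added the scalar 1/(max_eig * cur_diag_mean) to every entry of X.inverse(), i.e. to a multiple of the all-ones matrix rather than to the identity, and only an identity-scaled (diagonal) shift implements the eigenvalue soft-min the comment describes. A standalone sketch of the underlying identity (not part of the commit; plain torch only, with c standing in for 1/(max_eig * cur_diag_mean)): if X = Q diag(L) Q^T, then (X^-1 + c*I)^-1 = Q diag(1/(1/L + c)) Q^T.

    import torch

    torch.manual_seed(0)
    L = torch.tensor([0.5, 2.0, 8.0])           # chosen eigenvalues of a small SPD matrix
    Q, _ = torch.linalg.qr(torch.randn(3, 3))   # random orthonormal basis
    X = Q @ torch.diag(L) @ Q.T

    c = 1.0 / 4.0                               # stand-in for 1/(max_eig * cur_diag_mean)
    X_inv = X.inverse()
    X_inv.diagonal().add_(c)                    # shift the diagonal only, as the fix does
    X_smoothed = X_inv.inverse()

    # Eigenvalues of the result are the soft-min 1/(1/L + c):
    expected = (1.0 / (1.0 / L + c)).sort().values
    actual = torch.linalg.eigvalsh(X_smoothed)  # ascending, like `expected`
    assert torch.allclose(actual, expected, atol=1e-5)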
@@ -2107,6 +2114,17 @@ def _test_eve_cain():
         elif iter == 3: optim = PrAdam(m.parameters(), lr=0.03, max_block_size=100)
         scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)

+        #TEMP
+        if iter == 3:
+            a = torch.randn(5, 10, 5, 5)
+            a = torch.matmul(a, a.transpose(2, 3))
+            b1 = optim._smooth_cov(a, 0.1, 4.0, 0.999999999)
+            b2 = optim._smooth_cov(a, 0.1, 4.0, 1.0)
+            diff = (b1 - b2)
+            ratio = (diff**2).sqrt().mean() / (b1**2).sqrt().mean()
+            logging.info(f"ratio = {ratio}")
+            assert ratio < 0.01
+
         start = timeit.default_timer()
         avg_loss = 0.0
         for epoch in range(150):
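
The temporary block drives both code paths of _smooth_cov over a random batch of SPD matrices: power=0.999999999 takes the SVD branch, power=1.0 takes the fixed inverse branch, and the assertion requires them to agree to within 1%. Since (diff**2).sqrt() is just diff.abs(), the ratio is the mean absolute difference over the mean absolute value of b1; a similar check via Frobenius norms (a sketch, not in the commit) would be:

    rel_err = (b1 - b2).norm() / b1.norm()   # Frobenius-norm relative error
    assert rel_err < 0.01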