This version performs way worse, but has bugs fixed; can optimize from here.

Daniel Povey 2022-07-23 08:11:20 +08:00
parent dd10eb140f
commit dee496145d


@@ -143,7 +143,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         betas=(0.9, 0.98),
         size_lr_scale=0.1,
         min_lr_factor=(0.05, 0.05, 0.05),
-        max_lr_factor=(10.0, 10.0, 10.0),
+        max_lr_factor=(100.0, 100.0, 100.0),
         param_rms_smooth0=0.75,
         param_rms_smooth1=0.25,
         eps=1.0e-08,
@@ -617,7 +617,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # Now Q is of shape (batch_size, num_blocks, 1, 1, block_size, block_size)
         # [batch_index, block_index, diagonalized_coordinate, canonical_coordinate],
         # so we need to transpose Q as we convert M to the diagonalized co-ordinate.
-        M = torch.matmul(M, Q.transpose(2, 3))  # (batch_size, num_blocks, x, y, z, block_size)
+        #M = torch.matmul(M, Q.transpose(2, 3))  # (batch_size, num_blocks, x, y, z, block_size)
+        M = torch.matmul(M, Q.transpose(-2, -1))  # (batch_size, num_blocks, x, y, z, block_size)
         M = _move_dim(M, 1, -2)  # (batch_size, x, y, z, num_blocks, block_size)
         M = M.reshape(*M.shape[:-2], size)  # (batch_size, x, y, z, size)
         cur_p = M.transpose(dim, -1)  # (batch_size, x, size, y, z)
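The bug being fixed here is easy to reproduce: with Q shaped (batch_size, num_blocks, 1, 1, block_size, block_size), dims 2 and 3 are the singleton axes, so Q.transpose(2, 3) is a no-op and the old code effectively multiplied by Q rather than Q^T. A minimal sketch, with made-up sizes:

    import torch

    batch_size, num_blocks, block_size = 2, 3, 4
    Q = torch.randn(batch_size, num_blocks, 1, 1, block_size, block_size)

    # transpose(2, 3) swaps the two singleton dims, leaving Q unchanged:
    assert torch.equal(Q.transpose(2, 3), Q)
    # transpose(-2, -1) swaps the last two dims -- the actual matrix transpose:
    assert torch.equal(Q.transpose(-2, -1), Q.permute(0, 1, 2, 3, 5, 4))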
@@ -638,8 +639,17 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # spectrum").  We scale p so that it matches the accumulated stats,
         # the idea is to ensure it doesn't have any too-small eigenvalues
         # (where the stats permit).
         scale = (S / cur_param_var.clamp(min=eps)).sqrt()
+        if True:
+            S_tmp = S.reshape(batch_size, size)
+            cur_tmp = cur_param_var.reshape(batch_size, size)
+            scale_tmp = scale.reshape(batch_size, size)
+            skip = 10 if size > 40 else 1
+            logging.info(f"cur_param_var={cur_tmp[0][::skip]}, S={S_tmp[0][::skip]}, scale={scale_tmp[0][::skip]}")
         if random.random() < 0.01:
             skip = 10 if size < 20 else 1
             logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::skip]}, cur_param_var={cur_param_var[0].flatten()[::skip]}, S={S[0].flatten()[::skip]}")
@ -752,7 +762,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
be solving (eqn:2), and then computing: be solving (eqn:2), and then computing:
Z = U Z' U^T. Z = U Z' U^T.
A solution to (eqn:1) is as follows. We are going to be using a Cholesky-based solution in A solution to (eqn:2) is as follows. We are going to be using a Cholesky-based solution in
favor of one that requires SVD or eigenvalue decomposition, because it is much faster (we first favor of one that requires SVD or eigenvalue decomposition, because it is much faster (we first
have to be careful that the input is not close to singular, though). have to be careful that the input is not close to singular, though).
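To make the docstring's recipe concrete, here is a standalone sketch of the Cholesky-based solve of Z' G' Z' = P' for diagonal G', using torch.linalg calls in place of the file's _svd and .cholesky() helpers; P, g, and the sizes are made up. Writing P' = C C^T and Z' = C (C^T G' C)^{-1/2} C^T, one can check Z' G' Z' = C (C^T G' C)^{-1/2} (C^T G' C) (C^T G' C)^{-1/2} C^T = C C^T = P'.

    import torch

    torch.manual_seed(0)
    n = 5
    A = torch.randn(n, n, dtype=torch.float64)
    P = A @ A.T + n * torch.eye(n, dtype=torch.float64)  # SPD stand-in for P'
    g = torch.rand(n, dtype=torch.float64) + 0.5         # diagonal of G'

    C = torch.linalg.cholesky(P)            # P = C C^T
    CGC = (C.T * g) @ C                     # C^T diag(g) C
    S, U = torch.linalg.eigh(CGC)           # CGC = U diag(S) U^T
    Z = ((C @ U) * S.rsqrt()) @ (C @ U).T   # Z = C (CGC)^{-1/2} C^T

    assert torch.allclose(Z @ torch.diag(g) @ Z, P)  # Z G Z = P holds
    assert torch.allclose(Z, Z.T)                    # and Z is symmetric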
@@ -801,14 +811,21 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         P_prime = self._smooth_param_cov(group, p_shape, P_prime, G_prime)
-        C = P_prime.cholesky()  # P_prime = torch.matmul(C, C.transpose(2, 3))
+        # C will satisfy: P_prime == torch.matmul(C, C.transpose(2, 3));
+        # C is of shape (batch_size, num_blocks, block_size, block_size).
+        #def _fake_cholesky(X):
+        #    U_, S_, _ = _svd(X)
+        #    return U_ * S_.sqrt().unsqueeze(-2)
+        #C = _fake_cholesky(P_prime)
+        C = P_prime.cholesky()
-        # CGC = (C^T G' C), which would be torch.matmul(C.transpose(2, 3), torch.matmul(G, C))
+        # CGC = (C^T G' C); it would be torch.matmul(C.transpose(2, 3), torch.matmul(G, C))
         # if G were formatted as a diagonal matrix, but G is just the diagonal.
-        CGC = torch.matmul(C.transpose(2, 3) * G_prime.unsqueeze(-2),
-                           C)
+        CGC = torch.matmul(C.transpose(2, 3) * G_prime.unsqueeze(-2), C)
         U, S, _ = _svd(CGC)  # Need SVD to compute CGC^{-0.5}
-        # next we compute (eqn:3).  The thing in the parentheses, GCC^{-0.5},
+        # next we compute (eqn:3).  The thing in the parentheses, CGC^{-0.5},
         # can be written as U S^{-0.5} U^T, so the whole thing is
         # (C U) S^{-0.5} (C U)^T
         CU = torch.matmul(C, U)
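The one-liner for CGC relies on the broadcasting trick the comment describes: multiplying by a diagonal matrix from the right just scales columns, so diag(G') never has to be materialized. A quick check of the equivalence, with hypothetical shapes:

    import torch

    torch.manual_seed(0)
    C = torch.randn(2, 3, 4, 4)   # (batch_size, num_blocks, block_size, block_size)
    g = torch.rand(2, 3, 4)       # G' stored as its diagonal, one vector per block

    fast = torch.matmul(C.transpose(-2, -1) * g.unsqueeze(-2), C)
    slow = torch.matmul(C.transpose(-2, -1), torch.matmul(torch.diag_embed(g), C))
    assert torch.allclose(fast, slow, atol=1.0e-5)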
@@ -817,12 +834,26 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                              CU.transpose(2, 3))
         if True:
+            def _check_similar(x, y, name):
+                ratio = (y - x).abs().sum() / x.abs().sum()
+                if ratio > 0.0001:
+                    logging.warn(f"Check {name} failed, ratio={ratio.item()}")
+            def _check_symmetric(x, x_name):
+                diff = x - x.transpose(-2, -1)
+                ratio = diff.abs().sum() / x.abs().sum()
+                if ratio > 0.0001:
+                    logging.warn(f"{x_name} is not symmetric: ratio={ratio.item()}")
+            _check_similar(P_prime, torch.matmul(C, C.transpose(2, 3)), "cholesky_check")
+            _check_symmetric(Z_prime, "Z_prime")
+            _check_symmetric(P_prime, "P_prime")
             # A check.
             # Z_prime is supposed to satisfy: Z_prime G_prime Z_prime = P_prime   (eqn:2)
             P_prime_check = torch.matmul(Z_prime * G_prime.unsqueeze(-2), Z_prime)
-            diff_ratio = (P_prime - P_prime_check).abs().sum() / P_prime.abs().sum()
-            if diff_ratio > 0.01:
-                logging.warn(f"Z_prime does not satisfy its definition, diff_ratio = {diff_ratio}")
+            diff_ratio = ((P_prime - P_prime_check) ** 2).sum().sqrt() / (P_prime ** 2).sum().sqrt()
+            if diff_ratio > 0.001:
+                logging.warn(f"Z_prime does not satisfy its definition, diff_ratio = {diff_ratio}, size={size}")
         Z = torch.matmul(U_g, torch.matmul(Z_prime, U_g.transpose(2, 3)))
         # OK, Z is the SPD transform that maps G to P, as in Z G Z = P.
         # We just need the basis that diagonalizes this.
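The revised diff_ratio is a relative Frobenius-norm error rather than the earlier mean-absolute ratio, and the tightened 0.001 threshold means it fires only when Z_prime genuinely fails its defining equation. A toy illustration of the metric, with made-up numbers:

    import torch

    torch.manual_seed(0)
    P = torch.randn(4, 4, dtype=torch.float64)
    P_check = P + 1.0e-4 * torch.randn(4, 4, dtype=torch.float64)
    diff_ratio = ((P - P_check) ** 2).sum().sqrt() / (P ** 2).sum().sqrt()
    print(diff_ratio)  # ~1e-4, comfortably under the 0.001 warning threshold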
@@ -844,7 +875,11 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         U_prod = torch.matmul(U_z.transpose(2, 3), U_g)
         # this_P_proj shape: (batch_size, num_blocks, block_size)
         this_P_proj = _diag(torch.matmul(U_prod, torch.matmul(P_prime, U_prod.transpose(2, 3))))
-        P_proj[dim] = this_P_proj.reshape(batch_size, size)
+        P_proj[dim] = this_P_proj.clone().reshape(batch_size, size)
+        if True:
+            skip = 10 if P_proj[dim].shape[-1] > 40 else 1
+            logging.info(f"Eigs of P_proj are: {P_proj[dim][0,::skip]}")
         return P_proj
@@ -890,7 +925,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # Now P is as normalized as we can make it... do smoothing based on 'rank',
         # that is intended to compensate for bad estimates of P.
         batch_size = p_shape[0]
-        size = P_prime.shape[0]  # size of dim we are concerned with right now
+        size = P_norm.shape[0]  # size of dim we are concerned with right now
         # `rank` is the rank of P_prime if we were to estimate it from just one
         # parameter tensor.  We average it over time, but actually it won't be changing
         # too much, so `rank` does tell us something.
@@ -908,7 +943,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # add rank-dependent smoothing amount to diagonal of P_prime.  _diag() returns an aliased tensor.
         # we don't need to multiply `smooth` by anything, because at this point, P_prime should have
         # diagonal elements close to 1.
-        _diag(P_prime).add_(smooth)
+        _diag(P_norm).add_(smooth)
         P_norm = self._smooth_cov(P_norm,
                                   group["min_lr_factor"][0],
@@ -927,7 +962,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         G_prime_rms = G_prime.sqrt()
         G_prime_scale = G_prime_rms.unsqueeze(-1) * G_prime_rms.unsqueeze(-2)
         # P_gnorm is a version of P_prime that is scaled relative to G, i.e.
-        # scaled in such a way that would make G the unit matrix.
+        # scaled in a way that would make G the unit matrix.
         P_gnorm = P_prime / G_prime_scale
         # Apply another round of smoothing "relative to G"
         P_gnorm = self._smooth_cov(P_gnorm,
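Dividing by G_prime_scale, the outer product of G's rms values, is the congruence diag(g)^{-1/2} P diag(g)^{-1/2}: the change of basis in which G becomes the identity. A small sanity check, with hypothetical shapes:

    import torch

    torch.manual_seed(0)
    A = torch.randn(4, 4, dtype=torch.float64)
    P = A @ A.T                                    # SPD stand-in for P_prime
    g = torch.rand(4, dtype=torch.float64) + 0.5   # diagonal of G_prime

    g_rms = g.sqrt()
    scale = g_rms.unsqueeze(-1) * g_rms.unsqueeze(-2)
    assert torch.allclose(P / scale,
                          torch.diag(1.0 / g_rms) @ P @ torch.diag(1.0 / g_rms))
    # In the rescaled basis, G itself maps to the identity:
    assert torch.allclose(torch.diag(g) / scale, torch.eye(4, dtype=torch.float64))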
@@ -982,6 +1017,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             S = S / _mean(S, exclude_dims=[0], keepdim=True)
             return torch.matmul(U * S.unsqueeze(-2), U.transpose(2, 3))
         else:
+            X = X.clone()  # may be
             diag = _diag(X)  # Aliased with X
             mean_eig = _mean(diag, exclude_dims=[0], keepdim=True)
             eps = 1.0e-10  # prevent division by zero
@@ -1335,11 +1371,11 @@ def _diag(x: Tensor):
     elif x.ndim == 4:
         (B, C, M, M2) = x.shape
         assert M == M2
-        ans = x.as_strided(size=(B, C, M), stride=(stride[0], stride[1], stride[2] + stride[3])).contiguous()
+        ans = x.as_strided(size=(B, C, M), stride=(stride[0], stride[1], stride[2] + stride[3]))
     elif x.ndim == 2:
         (M, M2) = x.shape
         assert M == M2
-        ans = x.as_strided(size=(M,), stride=(stride[0] + stride[1],)).contiguous()
+        ans = x.as_strided(size=(M,), stride=(stride[0] + stride[1],))
     return ans
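Dropping the .contiguous() calls is what makes the in-place _diag(P_norm).add_(smooth) in the earlier hunk work: as_strided returns a view that shares storage with the input, whereas .contiguous() silently replaced it with a copy that absorbed the writes. A minimal demonstration of the difference:

    import torch

    X = torch.zeros(3, 3)
    diag_stride = (X.stride(0) + X.stride(1),)

    view = X.as_strided(size=(3,), stride=diag_stride)   # aliased with X
    view.add_(1.0)
    assert torch.equal(X, torch.eye(3))                  # the write reached X

    Y = torch.zeros(3, 3)
    copy = Y.as_strided(size=(3,), stride=diag_stride).contiguous()
    copy.add_(1.0)
    assert torch.equal(Y, torch.zeros(3, 3))             # Y untouched: the copy absorbed it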