Rework computation to reduce numerical roundoff

This commit is contained in:
Daniel Povey 2022-07-29 06:22:17 +08:00
parent 633cbd551a
commit 3c1fddaf48

View File

@ -858,6 +858,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
(C^T G' C)^{-0.5} C^T G' C (C^T G' C)^{-0.5} = I
which we can immediately see is the case. ]]
It's actually going to be more convenient to compute Z'^{1} (the inverse of Z'), because
ultimately only need the basis that diagonalizes Z, and (C^T G' C) may be singular.
It's not hard to work out that:
Z'^{-1} = C^{-T} (C^T G' C)^0.5 C^{-1} (eqn:4)
Args:
group: dict to look up config values
p_shape: the shape of the batch of identical-sized tensors we are optimizing
@ -908,56 +914,44 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# if G were formatted as a diagonal matrix, but G is just the diagonal.
CGC = torch.matmul(C.transpose(2, 3) * G_prime.unsqueeze(-2), C)
U, S, _ = _svd(CGC) # Need SVD to compute CGC^{-0.5}
# next we compute (eqn:3). The thing in the parenthesis is, CGC^{-0.5},
# can be written as U S^{-0.5} U^T, so the whole thing is
# (C U) S^{-0.5} (C U)^T.
#
# (C U) has good condition number because we smoothed the
# eigenvalues of P_prime; S^{-0.5} may not have good condition, in
# fact S could have zero eigenvalues (as a result of G having zero
# or very small elements). For numerical reasons, to avoid infinity
# in elements of S when we compute S^{-0.5}, we apply a floor here.
# Later, we will only care about the eigen-directions of Z_prime and
# not the eigenvalues, so aside from numerical consideratins we
# don't want to smooth S too aggressively here.
eps = 1.0e-12
S_floor = eps + (1.0e-06 * _mean(S, exclude_dims=[0], keepdim=True))
U, S, _ = _svd(CGC) # Need SVD to compute CGC^{0.5}
# (eqn:4) says: Z'^{-1} = C^{-T} (C^T G' C)^0.5 C^{-1}
# Since (C^T G' C)^0.5 = U S^{0.5} U^T,
# we have Z'^{-1} = X S^{0.5} X^T where X = C^{-T} U
X = torch.triangular_solve(U, C, upper=False, transpose=True)[0]
Z_prime_inv = torch.matmul(X * S.sqrt().unsqueeze(-2), X.transpose(2, 3))
S = _soft_floor(S, S_floor)
CU = torch.matmul(C, U)
S_inv_sqrt = 1.0 / S.sqrt()
Z_prime = torch.matmul(CU * S_inv_sqrt.unsqueeze(-2),
CU.transpose(2, 3))
if True:
def _check_similar(x, y, name):
ratio = (y-x).abs().sum() / x.abs().sum()
ratio = (y-x).abs().sum() / (x.abs().sum() + 1.0e-20)
if not (ratio < 0.0001):
logging.warn(f"Check {name} failed, ratio={ratio.item()}")
logging.warning(f"Check {name} failed, ratio={ratio.item()}, {x} vs. {y}")
def _check_symmetric(x, x_name):
diff = x - x.transpose(-2, -1)
ratio = diff.abs().sum() / x.abs().sum()
if not (ratio < 0.0001):
logging.warn(f"{x_name} is not symmetric: ratio={ratio.item()}")
logging.warning(f"{x_name} is not symmetric: ratio={ratio.item()}")
_check_similar(P_prime, torch.matmul(C, C.transpose(2, 3)), "cholesky_check")
_check_symmetric(Z_prime, "Z_prime")
_check_similar(torch.matmul(C.transpose(2, 3), X), U, "CTX")
_check_symmetric(Z_prime_inv, "Z_prime_inv")
_check_symmetric(P_prime, "P_prime")
# A check.
# Z_prime is supposed to satisfy: Z_prime G_prime Z_prime = P_prime (eqn:2)
P_prime_check = torch.matmul(Z_prime * G_prime.unsqueeze(-2), Z_prime)
diff_ratio_l2 = ((P_prime - P_prime_check) ** 2).sum().sqrt() // (P_prime ** 2).sum().sqrt()
diff_ratio_l1 = (P_prime - P_prime_check).abs().sum() // P_prime.abs().sum()
if not (diff_ratio_l2 < 0.00001) or diff_ratio_l1 > 0.03:
logging.warn(f"Z_prime does not satisfy its definition, diff_ratio_l{1,2} = {diff_ratio_l1.item(),diff_ratio_l2.item()}, size={size}")
# Z_prime is supposed to satisfy: Z_prime G_prime Z_prime = P_prime (eqn:2),
# or alternatively with the inverse,
# G_prime = Z_prime_inv P_prime Z_prime_inv
G_prime_check = _diag(torch.matmul(Z_prime_inv, torch.matmul(P_prime, Z_prime_inv)))
_check_similar(G_prime, G_prime_check, "G_prime")
# We really want the SVD on Z, which will be used for the learning-rate matrix
# Q, but Z_prime is better, numerically, to work on because it's closer to
# being diagonalized.
U_z_prime, S, _ = _svd(Z_prime)
U_z_prime, S_z_prime_inv, _ = _svd(Z_prime_inv)
U_z = torch.matmul(U_g, U_z_prime)
# We could obtain Z in two possible ways.
@ -970,7 +964,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
## U_z, S, _ = _svd(Z)
if True:
skip = 10 if S.shape[-1] > 40 else 1
logging.info(f"dim={dim}, G_prime is {G_prime[0,0,::skip]}, Eigs of Z are: {S[0,0,::skip]}")
logging.info(f"dim={dim}, G_prime is {G_prime[0,0,::skip]}, Eigs of Z_inv are: {S_z_prime_inv[0,0,::skip]}")
# state[f"Q_{dim}"] is indexed: [batch_idx, block_idx, diagonalized_coordinate, canonical_coordinate].
# so we need to transpose U_z as U_z is indexed
@ -1096,7 +1090,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
G_prime_mean = _mean(G_prime, exclude_dims=[0], keepdim=True)
G_prime_min = group["cov_min"][3]
# make sure G_prime has no zero eigs, and is unit mean.
#G_prime = _soft_floor(G_prime, G_prime_min * G_prime_mean, depth=3)
G_prime = G_prime + G_prime_min * G_prime_mean + 1.0e-20
G_prime /= _mean(G_prime, exclude_dims=[0], keepdim=True)
# it now has unit mean..
@ -1302,7 +1295,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
p: Tensor,
state: dict):
"""
This function does the core update of self._step, in the case where the tensor
This function does the core update of self.step(), in the case where the tensor
has more than 1 elements. It multiplies the moving-average gradient tensor by a
state["lr_{dim}"] on each non-trivial axis/dimension, does a length normalization,
and takes a step in that direction.
@ -1338,7 +1331,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
denom = exp_avg_sq.sqrt()
denom += eps + denom_rel_eps * _mean(denom, exclude_dims=[0], keepdim=True)
grad = grad / denom
# and project back..
@ -1525,40 +1517,6 @@ def _mean(x: Tensor,
def _soft_floor(x: Tensor, floor: Union[float, Tensor], depth: float = 3) -> Tensor:
"""
Applies a floor `floor` to x in a soft way such that values of x below the floor,
as low as floor**2, stay reasonably distinct. `floor` will be mapped to
3*floor,
sqrt(floor), 0 will be mapped to `floor`, floor**2 will be mapped to 2*floor,
and values above about sqrt(floor) will be approximately linear.
Args:
x : Tensor, a Tensor of any shape, to apply a floor to, will not be modified.
Its values must be >= 0.
floor: the floor to apply, must be >0. E.g. 1.0e-10
depth: the number of iterations in the algorithm; more means it will map
tiny values more aggressively to larger values. Recommend
not more than 5.
"""
terms = [x]
cur_pow = 1.0
for i in range(depth):
x = x.sqrt()
cur_pow *= 0.5
# the reason for dividing by `depth` is that regardless of the depth,
# we want `floor` itself to be mapped to 2*floor by the algorithm,
# so we want this sum to eventually equal `floor`, if x were equal
# to `floor`
terms.append(x * ((floor ** (1-cur_pow)) / depth))
ans = torch.stack(terms)
ans = ans.sum(dim=0)
return ans + floor
def _svd(x: Tensor):
# Wrapper for torch svd that catches errors and re-tries (to address bugs in
# some versions of torch SVD for CUDA)
@ -2309,16 +2267,6 @@ def _test_svd():
X2 = torch.matmul(U*S, V.t())
assert torch.allclose(X2, X, atol=0.001)
def _test_soft_floor():
x = torch.arange(100) - 50
y = x.to(torch.float32).exp()
z = _soft_floor(y, 1.0e-06)
print(f"y={y}, z={z}")
z = _soft_floor(torch.tensor(1.0e-06), 1.0e-06)
print("z = ", z)
assert (z - 3.0e-06).abs() < 1.0e-07
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
@ -2327,6 +2275,5 @@ if __name__ == "__main__":
s = subprocess.check_output("git status -uno .; git log -1", shell=True)
logging.info(s)
#_test_svd()
_test_soft_floor()
_test_eve_cain()
#_test_eden()