mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Rework computation to reduce numerical roundoff
This commit is contained in:
parent
633cbd551a
commit
3c1fddaf48
@ -858,6 +858,12 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
(C^T G' C)^{-0.5} C^T G' C (C^T G' C)^{-0.5} = I
|
||||
which we can immediately see is the case. ]]
|
||||
|
||||
It's actually going to be more convenient to compute Z'^{1} (the inverse of Z'), because
|
||||
ultimately only need the basis that diagonalizes Z, and (C^T G' C) may be singular.
|
||||
It's not hard to work out that:
|
||||
Z'^{-1} = C^{-T} (C^T G' C)^0.5 C^{-1} (eqn:4)
|
||||
|
||||
|
||||
Args:
|
||||
group: dict to look up config values
|
||||
p_shape: the shape of the batch of identical-sized tensors we are optimizing
|
||||
@ -908,56 +914,44 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
# if G were formatted as a diagonal matrix, but G is just the diagonal.
|
||||
CGC = torch.matmul(C.transpose(2, 3) * G_prime.unsqueeze(-2), C)
|
||||
|
||||
U, S, _ = _svd(CGC) # Need SVD to compute CGC^{-0.5}
|
||||
# next we compute (eqn:3). The thing in the parenthesis is, CGC^{-0.5},
|
||||
# can be written as U S^{-0.5} U^T, so the whole thing is
|
||||
# (C U) S^{-0.5} (C U)^T.
|
||||
#
|
||||
# (C U) has good condition number because we smoothed the
|
||||
# eigenvalues of P_prime; S^{-0.5} may not have good condition, in
|
||||
# fact S could have zero eigenvalues (as a result of G having zero
|
||||
# or very small elements). For numerical reasons, to avoid infinity
|
||||
# in elements of S when we compute S^{-0.5}, we apply a floor here.
|
||||
# Later, we will only care about the eigen-directions of Z_prime and
|
||||
# not the eigenvalues, so aside from numerical consideratins we
|
||||
# don't want to smooth S too aggressively here.
|
||||
eps = 1.0e-12
|
||||
S_floor = eps + (1.0e-06 * _mean(S, exclude_dims=[0], keepdim=True))
|
||||
U, S, _ = _svd(CGC) # Need SVD to compute CGC^{0.5}
|
||||
|
||||
# (eqn:4) says: Z'^{-1} = C^{-T} (C^T G' C)^0.5 C^{-1}
|
||||
# Since (C^T G' C)^0.5 = U S^{0.5} U^T,
|
||||
# we have Z'^{-1} = X S^{0.5} X^T where X = C^{-T} U
|
||||
X = torch.triangular_solve(U, C, upper=False, transpose=True)[0]
|
||||
Z_prime_inv = torch.matmul(X * S.sqrt().unsqueeze(-2), X.transpose(2, 3))
|
||||
|
||||
S = _soft_floor(S, S_floor)
|
||||
CU = torch.matmul(C, U)
|
||||
S_inv_sqrt = 1.0 / S.sqrt()
|
||||
Z_prime = torch.matmul(CU * S_inv_sqrt.unsqueeze(-2),
|
||||
CU.transpose(2, 3))
|
||||
|
||||
if True:
|
||||
def _check_similar(x, y, name):
|
||||
ratio = (y-x).abs().sum() / x.abs().sum()
|
||||
ratio = (y-x).abs().sum() / (x.abs().sum() + 1.0e-20)
|
||||
if not (ratio < 0.0001):
|
||||
logging.warn(f"Check {name} failed, ratio={ratio.item()}")
|
||||
logging.warning(f"Check {name} failed, ratio={ratio.item()}, {x} vs. {y}")
|
||||
|
||||
def _check_symmetric(x, x_name):
|
||||
diff = x - x.transpose(-2, -1)
|
||||
ratio = diff.abs().sum() / x.abs().sum()
|
||||
if not (ratio < 0.0001):
|
||||
logging.warn(f"{x_name} is not symmetric: ratio={ratio.item()}")
|
||||
logging.warning(f"{x_name} is not symmetric: ratio={ratio.item()}")
|
||||
|
||||
_check_similar(P_prime, torch.matmul(C, C.transpose(2, 3)), "cholesky_check")
|
||||
_check_symmetric(Z_prime, "Z_prime")
|
||||
_check_similar(torch.matmul(C.transpose(2, 3), X), U, "CTX")
|
||||
_check_symmetric(Z_prime_inv, "Z_prime_inv")
|
||||
_check_symmetric(P_prime, "P_prime")
|
||||
# A check.
|
||||
# Z_prime is supposed to satisfy: Z_prime G_prime Z_prime = P_prime (eqn:2)
|
||||
P_prime_check = torch.matmul(Z_prime * G_prime.unsqueeze(-2), Z_prime)
|
||||
diff_ratio_l2 = ((P_prime - P_prime_check) ** 2).sum().sqrt() // (P_prime ** 2).sum().sqrt()
|
||||
diff_ratio_l1 = (P_prime - P_prime_check).abs().sum() // P_prime.abs().sum()
|
||||
if not (diff_ratio_l2 < 0.00001) or diff_ratio_l1 > 0.03:
|
||||
logging.warn(f"Z_prime does not satisfy its definition, diff_ratio_l{1,2} = {diff_ratio_l1.item(),diff_ratio_l2.item()}, size={size}")
|
||||
# Z_prime is supposed to satisfy: Z_prime G_prime Z_prime = P_prime (eqn:2),
|
||||
# or alternatively with the inverse,
|
||||
# G_prime = Z_prime_inv P_prime Z_prime_inv
|
||||
G_prime_check = _diag(torch.matmul(Z_prime_inv, torch.matmul(P_prime, Z_prime_inv)))
|
||||
_check_similar(G_prime, G_prime_check, "G_prime")
|
||||
|
||||
|
||||
|
||||
# We really want the SVD on Z, which will be used for the learning-rate matrix
|
||||
# Q, but Z_prime is better, numerically, to work on because it's closer to
|
||||
# being diagonalized.
|
||||
U_z_prime, S, _ = _svd(Z_prime)
|
||||
U_z_prime, S_z_prime_inv, _ = _svd(Z_prime_inv)
|
||||
|
||||
U_z = torch.matmul(U_g, U_z_prime)
|
||||
# We could obtain Z in two possible ways.
|
||||
@ -970,7 +964,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
## U_z, S, _ = _svd(Z)
|
||||
if True:
|
||||
skip = 10 if S.shape[-1] > 40 else 1
|
||||
logging.info(f"dim={dim}, G_prime is {G_prime[0,0,::skip]}, Eigs of Z are: {S[0,0,::skip]}")
|
||||
logging.info(f"dim={dim}, G_prime is {G_prime[0,0,::skip]}, Eigs of Z_inv are: {S_z_prime_inv[0,0,::skip]}")
|
||||
|
||||
# state[f"Q_{dim}"] is indexed: [batch_idx, block_idx, diagonalized_coordinate, canonical_coordinate].
|
||||
# so we need to transpose U_z as U_z is indexed
|
||||
@ -1096,7 +1090,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
G_prime_mean = _mean(G_prime, exclude_dims=[0], keepdim=True)
|
||||
G_prime_min = group["cov_min"][3]
|
||||
# make sure G_prime has no zero eigs, and is unit mean.
|
||||
#G_prime = _soft_floor(G_prime, G_prime_min * G_prime_mean, depth=3)
|
||||
G_prime = G_prime + G_prime_min * G_prime_mean + 1.0e-20
|
||||
G_prime /= _mean(G_prime, exclude_dims=[0], keepdim=True)
|
||||
# it now has unit mean..
|
||||
@ -1302,7 +1295,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
p: Tensor,
|
||||
state: dict):
|
||||
"""
|
||||
This function does the core update of self._step, in the case where the tensor
|
||||
This function does the core update of self.step(), in the case where the tensor
|
||||
has more than 1 elements. It multiplies the moving-average gradient tensor by a
|
||||
state["lr_{dim}"] on each non-trivial axis/dimension, does a length normalization,
|
||||
and takes a step in that direction.
|
||||
@ -1338,7 +1331,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
||||
|
||||
denom = exp_avg_sq.sqrt()
|
||||
denom += eps + denom_rel_eps * _mean(denom, exclude_dims=[0], keepdim=True)
|
||||
|
||||
grad = grad / denom
|
||||
|
||||
# and project back..
|
||||
@ -1525,40 +1517,6 @@ def _mean(x: Tensor,
|
||||
|
||||
|
||||
|
||||
|
||||
def _soft_floor(x: Tensor, floor: Union[float, Tensor], depth: float = 3) -> Tensor:
|
||||
"""
|
||||
Applies a floor `floor` to x in a soft way such that values of x below the floor,
|
||||
as low as floor**2, stay reasonably distinct. `floor` will be mapped to
|
||||
3*floor,
|
||||
|
||||
sqrt(floor), 0 will be mapped to `floor`, floor**2 will be mapped to 2*floor,
|
||||
and values above about sqrt(floor) will be approximately linear.
|
||||
|
||||
Args:
|
||||
x : Tensor, a Tensor of any shape, to apply a floor to, will not be modified.
|
||||
Its values must be >= 0.
|
||||
floor: the floor to apply, must be >0. E.g. 1.0e-10
|
||||
depth: the number of iterations in the algorithm; more means it will map
|
||||
tiny values more aggressively to larger values. Recommend
|
||||
not more than 5.
|
||||
"""
|
||||
terms = [x]
|
||||
cur_pow = 1.0
|
||||
for i in range(depth):
|
||||
x = x.sqrt()
|
||||
cur_pow *= 0.5
|
||||
# the reason for dividing by `depth` is that regardless of the depth,
|
||||
# we want `floor` itself to be mapped to 2*floor by the algorithm,
|
||||
# so we want this sum to eventually equal `floor`, if x were equal
|
||||
# to `floor`
|
||||
terms.append(x * ((floor ** (1-cur_pow)) / depth))
|
||||
ans = torch.stack(terms)
|
||||
ans = ans.sum(dim=0)
|
||||
return ans + floor
|
||||
|
||||
|
||||
|
||||
def _svd(x: Tensor):
|
||||
# Wrapper for torch svd that catches errors and re-tries (to address bugs in
|
||||
# some versions of torch SVD for CUDA)
|
||||
@ -2309,16 +2267,6 @@ def _test_svd():
|
||||
X2 = torch.matmul(U*S, V.t())
|
||||
assert torch.allclose(X2, X, atol=0.001)
|
||||
|
||||
def _test_soft_floor():
|
||||
x = torch.arange(100) - 50
|
||||
y = x.to(torch.float32).exp()
|
||||
z = _soft_floor(y, 1.0e-06)
|
||||
print(f"y={y}, z={z}")
|
||||
|
||||
z = _soft_floor(torch.tensor(1.0e-06), 1.0e-06)
|
||||
print("z = ", z)
|
||||
assert (z - 3.0e-06).abs() < 1.0e-07
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
@ -2327,6 +2275,5 @@ if __name__ == "__main__":
|
||||
s = subprocess.check_output("git status -uno .; git log -1", shell=True)
|
||||
logging.info(s)
|
||||
#_test_svd()
|
||||
_test_soft_floor()
|
||||
_test_eve_cain()
|
||||
#_test_eden()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user