mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Initial version of fixing numerical issue, will continue though
This commit is contained in:
parent
b0f0c6c4ab
commit
5513f7fee5
@ -906,7 +906,20 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
U, S, _ = _svd(CGC) # Need SVD to compute CGC^{-0.5}
|
U, S, _ = _svd(CGC) # Need SVD to compute CGC^{-0.5}
|
||||||
# next we compute (eqn:3). The thing in the parenthesis is, CGC^{-0.5},
|
# next we compute (eqn:3). The thing in the parenthesis is, CGC^{-0.5},
|
||||||
# can be written as U S^{-0.5} U^T, so the whole thing is
|
# can be written as U S^{-0.5} U^T, so the whole thing is
|
||||||
# (C U) S^{-0.5} (C U)^T
|
# (C U) S^{-0.5} (C U)^T.
|
||||||
|
#
|
||||||
|
# (C U) has good condition number because we smoothed the
|
||||||
|
# eigenvalues of P_prime; S^{-0.5} may not have good condition, in
|
||||||
|
# fact S could have zero eigenvalues (as a result of G having zero
|
||||||
|
# or very small elements). For numerical reasons, to avoid infinity
|
||||||
|
# in elements of S when we compute S^{-0.5}, we apply a floor here.
|
||||||
|
# Later, we will only care about the eigen-directions of Z_prime and
|
||||||
|
# not the eigenvalues, so aside from numerical considerations we
|
||||||
|
# don't want to smooth S too aggressively here.
|
||||||
|
eps = 1.0e-12
|
||||||
|
S_floor = eps + (1.0e-06 * _mean(S, exclude_dims=[0], keepdim=True))
|
||||||
|
|
||||||
|
S = _soft_floor(S, S_floor)
|
||||||
CU = torch.matmul(C, U)
|
CU = torch.matmul(C, U)
|
||||||
S_inv_sqrt = 1.0 / S.sqrt()
|
S_inv_sqrt = 1.0 / S.sqrt()
|
||||||
Z_prime = torch.matmul(CU * S_inv_sqrt.unsqueeze(-2),
|
Z_prime = torch.matmul(CU * S_inv_sqrt.unsqueeze(-2),
|
||||||
@ -915,13 +928,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
if True:
|
if True:
|
||||||
def _check_similar(x, y, name):
|
def _check_similar(x, y, name):
|
||||||
ratio = (y-x).abs().sum() / x.abs().sum()
|
ratio = (y-x).abs().sum() / x.abs().sum()
|
||||||
if ratio > 0.0001:
|
if not (ratio < 0.0001):
|
||||||
logging.warn(f"Check {name} failed, ratio={ratio.item()}")
|
logging.warn(f"Check {name} failed, ratio={ratio.item()}")
|
||||||
|
|
||||||
def _check_symmetric(x, x_name):
|
def _check_symmetric(x, x_name):
|
||||||
diff = x - x.transpose(-2, -1)
|
diff = x - x.transpose(-2, -1)
|
||||||
ratio = diff.abs().sum() / x.abs().sum()
|
ratio = diff.abs().sum() / x.abs().sum()
|
||||||
if ratio > 0.0001:
|
if not (ratio < 0.0001):
|
||||||
logging.warn(f"{x_name} is not symmetric: ratio={ratio.item()}")
|
logging.warn(f"{x_name} is not symmetric: ratio={ratio.item()}")
|
||||||
|
|
||||||
_check_similar(P_prime, torch.matmul(C, C.transpose(2, 3)), "cholesky_check")
|
_check_similar(P_prime, torch.matmul(C, C.transpose(2, 3)), "cholesky_check")
|
||||||
@ -932,7 +945,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
P_prime_check = torch.matmul(Z_prime * G_prime.unsqueeze(-2), Z_prime)
|
P_prime_check = torch.matmul(Z_prime * G_prime.unsqueeze(-2), Z_prime)
|
||||||
diff_ratio_l2 = ((P_prime - P_prime_check) ** 2).sum().sqrt() // (P_prime ** 2).sum().sqrt()
|
diff_ratio_l2 = ((P_prime - P_prime_check) ** 2).sum().sqrt() // (P_prime ** 2).sum().sqrt()
|
||||||
diff_ratio_l1 = (P_prime - P_prime_check).abs().sum() // P_prime.abs().sum()
|
diff_ratio_l1 = (P_prime - P_prime_check).abs().sum() // P_prime.abs().sum()
|
||||||
if diff_ratio_l2 > 0.00001 or diff_ratio_l1 > 0.03:
|
if not (diff_ratio_l2 < 0.00001) or diff_ratio_l1 > 0.03:
|
||||||
logging.warn(f"Z_prime does not satisfy its definition, diff_ratio_l{1,2} = {diff_ratio_l1.item(),diff_ratio_l2.item()}, size={size}")
|
logging.warn(f"Z_prime does not satisfy its definition, diff_ratio_l{1,2} = {diff_ratio_l1.item(),diff_ratio_l2.item()}, size={size}")
|
||||||
Z = torch.matmul(U_g, torch.matmul(Z_prime, U_g.transpose(2, 3)))
|
Z = torch.matmul(U_g, torch.matmul(Z_prime, U_g.transpose(2, 3)))
|
||||||
# OK, Z is the SPD transform that maps G to P, as in Z G Z = P.
|
# OK, Z is the SPD transform that maps G to P, as in Z G Z = P.
|
||||||
@ -1065,8 +1078,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
|
|||||||
G_prime_mean = _mean(G_prime, exclude_dims=[0], keepdim=True)
|
G_prime_mean = _mean(G_prime, exclude_dims=[0], keepdim=True)
|
||||||
G_prime_min = group["cov_min"][3]
|
G_prime_min = group["cov_min"][3]
|
||||||
# make sure G_prime has no zero eigs, and is unit mean.
|
# make sure G_prime has no zero eigs, and is unit mean.
|
||||||
G_prime = ((G_prime + eps + G_prime_min * G_prime_mean) /
|
#G_prime = _soft_floor(G_prime, G_prime_min * G_prime_mean, depth=3)
|
||||||
(G_prime_mean * (1+G_prime_min) + eps))
|
G_prime += G_prime_min * G_prime_mean
|
||||||
|
G_prime /= _mean(G_prime, exclude_dims=[0], keepdim=True)
|
||||||
# it now has unit mean..
|
# it now has unit mean..
|
||||||
G_prime_max = group["cov_max"][3]
|
G_prime_max = group["cov_max"][3]
|
||||||
G_prime = 1. / (1./G_prime + 1./G_prime_max) # apply max
|
G_prime = 1. / (1./G_prime + 1./G_prime_max) # apply max
|
||||||
@ -1491,6 +1505,38 @@ def _mean(x: Tensor,
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _soft_floor(x: Tensor, floor: Union[float, Tensor], depth: float = 3) -> Tensor:
|
||||||
|
"""
|
||||||
|
Applies a floor `floor` to x in a soft way such that values of x below the floor,
|
||||||
|
as low as floor**2, stay reasonably distinct. `floor` will be mapped to
|
||||||
|
3*floor,
|
||||||
|
|
||||||
|
sqrt(floor), 0 will be mapped to `floor`, floor**2 will be mapped to 2*floor,
|
||||||
|
and values above about sqrt(floor) will be approximately linear.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x : Tensor, a Tensor of any shape, to apply a floor to, will not be modified.
|
||||||
|
Its values must be >= 0.
|
||||||
|
floor: the floor to apply, must be >0. E.g. 1.0e-10
|
||||||
|
depth: the number of iterations in the algorithm; more means it will map
|
||||||
|
tiny values more aggressively to larger values. Recommend
|
||||||
|
not more than 5.
|
||||||
|
"""
|
||||||
|
terms = [x]
|
||||||
|
cur_pow = 1.0
|
||||||
|
for i in range(depth):
|
||||||
|
x = x.sqrt()
|
||||||
|
cur_pow *= 0.5
|
||||||
|
# the reason for dividing by `depth` is that regardless of the depth,
|
||||||
|
# we want `floor` itself to be mapped to 2*floor by the algorithm,
|
||||||
|
# so we want this sum to eventually equal `floor`, if x were equal
|
||||||
|
# to `floor`
|
||||||
|
terms.append(x * ((floor ** (1-cur_pow)) / depth))
|
||||||
|
ans = torch.stack(terms)
|
||||||
|
ans = ans.sum(dim=0)
|
||||||
|
return ans + floor
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _svd(x: Tensor):
|
def _svd(x: Tensor):
|
||||||
# Wrapper for torch svd that catches errors and re-tries (to address bugs in
|
# Wrapper for torch svd that catches errors and re-tries (to address bugs in
|
||||||
@ -2242,6 +2288,16 @@ def _test_svd():
|
|||||||
X2 = torch.matmul(U*S, V.t())
|
X2 = torch.matmul(U*S, V.t())
|
||||||
assert torch.allclose(X2, X, atol=0.001)
|
assert torch.allclose(X2, X, atol=0.001)
|
||||||
|
|
||||||
|
def _test_soft_floor():
    # Smoke-test _soft_floor across ~100 orders of magnitude, then pin the
    # one exactly-known point of the mapping: floor itself maps to 3*floor.
    y = (torch.arange(100, dtype=torch.float32) - 50).exp()
    z = _soft_floor(y, 1.0e-06)
    print(f"y={y}, z={z}")

    z = _soft_floor(torch.tensor(1.0e-06), 1.0e-06)
    print("z = ", z)
    assert (z - 3.0e-06).abs() < 1.0e-07
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
@ -2250,5 +2306,6 @@ if __name__ == "__main__":
|
|||||||
s = subprocess.check_output("git status -uno .; git log -1", shell=True)
|
s = subprocess.check_output("git status -uno .; git log -1", shell=True)
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
#_test_svd()
|
#_test_svd()
|
||||||
|
_test_soft_floor()
|
||||||
_test_eve_cain()
|
_test_eve_cain()
|
||||||
#_test_eden()
|
#_test_eden()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user