Reworking the computation of Z to be numerically better.

This commit is contained in:
Daniel Povey 2022-07-25 06:37:26 +08:00
parent 5513f7fee5
commit 3acdf3b395

View File

@ -947,10 +947,22 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
diff_ratio_l1 = (P_prime - P_prime_check).abs().sum() // P_prime.abs().sum() diff_ratio_l1 = (P_prime - P_prime_check).abs().sum() // P_prime.abs().sum()
if not (diff_ratio_l2 < 0.00001) or diff_ratio_l1 > 0.03: if not (diff_ratio_l2 < 0.00001) or diff_ratio_l1 > 0.03:
logging.warn(f"Z_prime does not satisfy its definition, diff_ratio_l{1,2} = {diff_ratio_l1.item(),diff_ratio_l2.item()}, size={size}") logging.warn(f"Z_prime does not satisfy its definition, diff_ratio_l{1,2} = {diff_ratio_l1.item(),diff_ratio_l2.item()}, size={size}")
Z = torch.matmul(U_g, torch.matmul(Z_prime, U_g.transpose(2, 3)))
# OK, Z is the SPD transform that maps G to P, as in Z G Z = P.
# We just need the basis U_z that diagonalizes this. # We really want the SVD on Z, which will be used for the learning-rate matrix
U_z, S, _ = _svd(Z) # Q, but Z_prime is better, numerically, to work on because it's closer to
# being diagonalized.
U_z_prime, S, _ = _svd(Z_prime)
U_z = torch.matmul(U_g, U_z_prime)
# We could obtain Z in two possible ways.
# Commenting the check below.
# Z_a = torch.matmul(U_z * S.unsqueeze(-2), U_z.transpose(2, 3))
# Z_b = torch.matmul(U_g, torch.matmul(Z_prime, U_g.transpose(2, 3)))
# _check_similar(Z_a, Z_b, "Z")
## OK, Z is the SPD transform that maps G to P, as in Z G Z = P.
## We just need the basis U_z that diagonalizes this.
## U_z, S, _ = _svd(Z)
if True: if True:
skip = 10 if S.shape[-1] > 40 else 1 skip = 10 if S.shape[-1] > 40 else 1
logging.info(f"dim={dim}, G_prime is {G_prime[0,0,::skip]}, Eigs of Z are: {S[0,0,::skip]}") logging.info(f"dim={dim}, G_prime is {G_prime[0,0,::skip]}, Eigs of Z are: {S[0,0,::skip]}")