diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 6afdadbea..bc727424d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -947,10 +947,22 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of diff_ratio_l1 = (P_prime - P_prime_check).abs().sum() // P_prime.abs().sum() if not (diff_ratio_l2 < 0.00001) or diff_ratio_l1 > 0.03: logging.warn(f"Z_prime does not satisfy its definition, diff_ratio_l{1,2} = {diff_ratio_l1.item(),diff_ratio_l2.item()}, size={size}") - Z = torch.matmul(U_g, torch.matmul(Z_prime, U_g.transpose(2, 3))) - # OK, Z is the SPD transform that maps G to P, as in Z G Z = P. - # We just need the basis U_z that diagonalizes this. - U_z, S, _ = _svd(Z) + + + # We really want the SVD on Z, which will be used for the learning-rate matrix + # Q, but Z_prime is better, numerically, to work on because it's closer to + # being diagonalized. + U_z_prime, S, _ = _svd(Z_prime) + + U_z = torch.matmul(U_g, U_z_prime) + # We could obtain Z in two possible ways. + # Commenting the check below. + # Z_a = torch.matmul(U_z * S.unsqueeze(-2), U_z.transpose(2, 3)) + # Z_b = torch.matmul(U_g, torch.matmul(Z_prime, U_g.transpose(2, 3))) + # _check_similar(Z_a, Z_b, "Z") + ## OK, Z is the SPD transform that maps G to P, as in Z G Z = P. + ## We just need the basis U_z that diagonalizes this. + ## U_z, S, _ = _svd(Z) if True: skip = 10 if S.shape[-1] > 40 else 1 logging.info(f"dim={dim}, G_prime is {G_prime[0,0,::skip]}, Eigs of Z are: {S[0,0,::skip]}")