diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 4146ec4b4..83ce90a92 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -166,7 +166,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         lr=3e-02,
         betas=(0.9, 0.98),
         size_lr_scale=0.1,
-        cov_min=(0.025, 0.0025, 0.02, 0.0001, 0.1),
+        cov_min=(0.025, 0.0025, 0.02, 0.0001, 0.025),
         cov_max=(10.0, 80.0, 5.0, 400.0, 100.0),
         cov_pow=(1.0, 1.0, 1.0, 1.0),
         param_rms_smooth0=0.4,
@@ -969,30 +969,24 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 _check_similar(G_prime, G_prime_check, "G_prime")

-            Z_prime_inv_diag = _diag(Z_prime_inv)  # aliased with Z_prime_inv
+            Z_inv = torch.matmul(U_g, torch.matmul(Z_prime_inv, U_g.transpose(2, 3)))
+            Z_inv = 0.5 * (Z_inv + Z_inv.transpose(2, 3))  # make sure exactly symmetric
+
+            Z_inv_diag = _diag(Z_inv)  # aliased with Z_inv
             # this is smoothing Z relative to its own diagonal.  This is z_inv,
             # so by applying a minimum here, we are applying a maximum of the
             # eigs of Z after normalizing so the diagonal is 1.
-            Z_prime_inv_diag *= (1. + group["cov_min"][4])
+            Z_inv_diag *= (1. + group["cov_min"][4])

             # We really want the SVD on Z, which will be used for the learning-rate matrix
             # Q, but Z_prime is better, numerically, to work on because it's closer to
             # being diagonalized.
-            U_z_prime, S_z_prime_inv, _ = _svd(Z_prime_inv)
+            U_z, S_z_inv, _ = _svd(Z_inv)

-            U_z = torch.matmul(U_g, U_z_prime)
-            # We could obtain Z in two possible ways.
-            # Commenting the check below.
-            # Z_a = torch.matmul(U_z * S.unsqueeze(-2), U_z.transpose(2, 3))
-            # Z_b = torch.matmul(U_g, torch.matmul(Z_prime, U_g.transpose(2, 3)))
-            # _check_similar(Z_a, Z_b, "Z")
-            ## OK, Z is the SPD transform that maps G to P, as in Z G Z = P.
-            ## We just need the basis U_z that diagonalizes this.
-            ## U_z, S, _ = _svd(Z)
             if True:
                 skip = 10 if S.shape[-1] > 40 else 1
-                logging.info(f"dim={dim}, G_prime is {G_prime[0,0,::skip]}, Eigs of Z_inv are: {S_z_prime_inv[0,0,::skip]}")
+                logging.info(f"dim={dim}, G_prime is {G_prime[0,0,::skip]}, Eigs of Z_inv are: {S_z_inv[0,0,::skip]}")

             # state[f"Q_{dim}"] is indexed: [batch_idx, block_idx, diagonalized_coordinate, canonical_coordinate].
             # so we need to transpose U_z as U_z is indexed
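
For reviewers who want to try the new ordering outside the optimizer, here is a minimal standalone sketch of what the added lines do: rotate Z'_inv back to the canonical basis with U_g, symmetrize it, scale its diagonal by (1 + cov_min[4]), and take the SVD of Z_inv directly rather than going through U_z_prime. The shapes, the random stand-ins for U_g and Z_prime_inv, and the use of torch.linalg.svd / Tensor.diagonal in place of the file's _svd / _diag helpers are illustrative assumptions, not the optimizer's actual setup.

```python
# Standalone sketch (not part of the patch) of the new Z_inv code path.
import torch

batch, blocks, size = 2, 3, 8
cov_min_4 = 0.025  # corresponds to group["cov_min"][4] after this change

# Random orthogonal U_g and a diagonal SPD Z'_inv stand in for optimizer state.
A = torch.randn(batch, blocks, size, size)
U_g, _, _ = torch.linalg.svd(A)
d = torch.rand(batch, blocks, size) + 0.1
Z_prime_inv = torch.diag_embed(d)

# Rotate back to the canonical basis: Z_inv = U_g Z'_inv U_g^T, then symmetrize.
Z_inv = torch.matmul(U_g, torch.matmul(Z_prime_inv, U_g.transpose(2, 3)))
Z_inv = 0.5 * (Z_inv + Z_inv.transpose(2, 3))  # make sure exactly symmetric

# Scaling the diagonal of Z_inv up applies a maximum to the eigs of Z after
# normalizing so the diagonal is 1 (mirrors the comment in the patch).
Z_inv_diag = Z_inv.diagonal(dim1=2, dim2=3)  # view aliased with Z_inv
Z_inv_diag *= 1.0 + cov_min_4

# SVD of Z_inv directly; U_z plays the role of the learning-rate basis.
U_z, S_z_inv, _ = torch.linalg.svd(Z_inv)
print(S_z_inv[0, 0])  # eigenvalues of Z_inv for the first block
```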