Improve debugging output.

This commit is contained in:
Daniel Povey 2022-07-25 09:02:36 +08:00
parent 854c2965a9
commit fe595f8772

View File

@ -1001,9 +1001,9 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
# debug output
this_P_proj_unsmoothed = _diag(torch.matmul(U_prod, torch.matmul(P_prime_unsmoothed,
U_prod.transpose(2, 3))))
this_P_proj_unsmoothed = this_P_proj_unsmoothed.clone().reshape(batch_size, size)
skip = 10 if P_proj[dim].shape[-1] > 40 else 1
logging.info(f"dim={dim}, diag of P_proj is: {P_proj[dim][0,::skip]}, diag of unsmoothed P_proj is: {this_P_proj_unsmoothed[0,::skip]}")
G_proj_unsmoothed = _diag(torch.matmul(U_prod * G_prime.unsqueeze(-2), U_prod.transpose(2, 3)))
skip = 10 if block_size > 40 else 1
logging.info(f"dim={dim}, diag of P_proj is: {P_proj[dim][0,0:block_size:skip]}, diag of unsmoothed P_proj is: {this_P_proj_unsmoothed[0,0,::skip]}, diag of unsmoothed G_proj is {G_proj_unsmoothed[0,0,::skip]}")
# P_proj won't be needed if simple_update == True.
return P_proj