mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-19 05:54:20 +00:00

Saving version I am trying to debug

This commit is contained in:
parent 962e95f119
commit ba96439c76
@@ -419,7 +419,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 # rate matrices at most every other time we reach here, and
 # less frequently than that later in training.
 #self._update_param_scales(group, p, state, P_proj)
-self._update_param_scales_simple(group, p, state, P_proj)
+#self._update_param_scales_simple(group, p, state, P_proj)

 # We won't be doing this any more.
 #self._diagonalize_grad_cov(group, p, state)
@@ -608,6 +608,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 logging.info(f"step={step}, dim={dim}, size={size}, scale min,max,mean={scale_min,scale_max,scale_mean}")

 Q *= this_P_proj.sqrt()
+logging.info(f"Q rms = {(Q**2).mean().sqrt()}, abs-rms = {Q.abs().mean()}")

 def _update_param_scales(self,
                          group: dict,
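The line added in this hunk only logs two summary statistics of Q after the rescale. As a side note, a minimal sketch of the same diagnostics on an arbitrary tensor, using only torch (the helper name q_stats is illustrative, not from the repo):

import torch

def q_stats(Q: torch.Tensor):
    # root-mean-square of the entries
    rms = (Q ** 2).mean().sqrt()
    # mean absolute value; for roughly Gaussian entries this is about
    # 0.8 * rms, so a much smaller value hints at a few dominant entries
    abs_rms = Q.abs().mean()
    return rms.item(), abs_rms.item()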
@@ -950,13 +951,42 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
 # this_P_proj shape: (batch_size, num_blocks, block_size)
 this_P_proj = _diag(torch.matmul(U_prod, torch.matmul(P_prime, U_prod.transpose(2, 3))))
 P_proj[dim] = this_P_proj.clone().reshape(batch_size, size)

+simple_update = True
+if simple_update:
+    # normalize the scales in a way that preserves the Frobenius norm of the
+    # projected parameter deltas
+    P_rms = this_P_proj.clone()
+    P_rms = P_rms / _mean(P_rms, exclude_dims=[0], keepdim=True)
+    P_rms_compare = this_P_proj.clone()
+    P_rms_compare /= _mean(P_rms_compare, exclude_dims=[0], keepdim=True)
+
+    mean_diff = (P_rms - P_rms_compare)
+    if True:
+        ratio = mean_diff.abs().sum() / P_rms.abs().sum()
+        logging.info(f"ratio for division is {ratio}, shapes are {P_rms.shape}, {P_rms_compare.shape}")
+        if ratio > 1.0e-10:
+            logging.warn(f"P_rms={P_rms}, P_rms_compare={P_rms_compare}")
+    scale = P_rms.unsqueeze(-1).sqrt()
+    Q *= scale
+    logging.info(f"Q rms = {(Q**2).mean().sqrt()}, abs-rms = {Q.abs().mean()}")
+
+    # no iterative stuff, just use sqrt(P_proj) as scale on Q.  If this is False, we need to
+    # call self._update_param_scales(...) from the calling function.
+    if True:
+        # debug output
+        step = state["step"]
+        scale_min, scale_max, scale_mean = scale.min().item(), scale.max().item(), scale.mean().item()
+        logging.info(f"step={step}, dim={dim}, size={size}, scale min,max,mean={scale_min,scale_max,scale_mean}")
 if True:
+    # debug output
     this_P_proj_unsmoothed = _diag(torch.matmul(U_prod, torch.matmul(P_prime_unsmoothed,
                                                                      U_prod.transpose(2, 3))))
     this_P_proj_unsmoothed = this_P_proj_unsmoothed.clone().reshape(batch_size, size)
     skip = 10 if P_proj[dim].shape[-1] > 40 else 1
     logging.info(f"dim={dim}, diag of P_proj is: {P_proj[dim][0,::skip]}, diag of unsmoothed P_proj is: {this_P_proj_unsmoothed[0,::skip]}")

+# P_proj won't be needed if simple_update == True.
 return P_proj

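The simple_update branch above divides the projected diagonal by its per-parameter mean and uses the square root as a scale on Q; P_rms_compare only repeats the same division in place, so the logged ratio is expected to stay near zero. A minimal standalone sketch of that path, assuming a 2-D (batch_size, size) layout and plain torch.mean in place of the repo's _mean(..., exclude_dims=[0]) helper; the names simple_scale and P_proj_diag are illustrative:

import torch

def simple_scale(P_proj_diag: torch.Tensor, Q: torch.Tensor) -> torch.Tensor:
    # P_proj_diag: (batch_size, size), diagonal of the projected covariance
    # Q: (batch_size, size, size), basis whose rows get rescaled
    P_rms = P_proj_diag / P_proj_diag.mean(dim=1, keepdim=True)
    # same division done in place; mirrors the P_rms / P_rms_compare check
    P_rms_compare = P_proj_diag.clone()
    P_rms_compare /= P_rms_compare.mean(dim=1, keepdim=True)
    assert (P_rms - P_rms_compare).abs().sum() / P_rms.abs().sum() < 1.0e-10
    # unit-mean scales, applied along the row dimension of Q
    scale = P_rms.unsqueeze(-1).sqrt()
    return Q * scale

With simple_update enabled, the commit applies this scale directly instead of going through the iterative _update_param_scales path, as the "no iterative stuff" comment in the hunk notes.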