diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index a8a0a05b0..bd0435db5 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -419,7 +419,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             # rate matrices at most every other time we reach here, and
             # less frequently than that later in training.
             #self._update_param_scales(group, p, state, P_proj)
-            self._update_param_scales_simple(group, p, state, P_proj)
+            #self._update_param_scales_simple(group, p, state, P_proj)  # We won't be doing this any more.
             #self._diagonalize_grad_cov(group, p, state)
@@ -608,6 +608,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 logging.info(f"step={step}, dim={dim}, size={size}, scale min,max,mean={scale_min,scale_max,scale_mean}")

             Q *= this_P_proj.sqrt()
+            logging.info(f"Q rms = {(Q**2).mean().sqrt()}, abs-rms = {Q.abs().mean()}")

     def _update_param_scales(self,
                              group: dict,
@@ -950,13 +951,42 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             # this_P_proj shape: (batch_size, num_blocks, block_size)
             this_P_proj = _diag(torch.matmul(U_prod, torch.matmul(P_prime, U_prod.transpose(2, 3))))
             P_proj[dim] = this_P_proj.clone().reshape(batch_size, size)
+
+            simple_update = True
+            if simple_update:
+                # normalize the scales in a way that preserves the Frobenius norm of the
+                # projected parameter deltas
+                P_rms = this_P_proj.clone()
+                P_rms = P_rms / _mean(P_rms, exclude_dims=[0], keepdim=True)
+                P_rms_compare = this_P_proj.clone()
+                P_rms_compare /= _mean(P_rms_compare, exclude_dims=[0], keepdim=True)
+
+                mean_diff = (P_rms - P_rms_compare)
+                if True:
+                    ratio = mean_diff.abs().sum() / P_rms.abs().sum()
+                    logging.info(f"ratio for division is {ratio}, shapes are {P_rms.shape}, {P_rms_compare.shape}")
+                    if ratio > 1.0e-10:
+                        logging.warn(f"P_rms={P_rms}, P_rms_compare={P_rms_compare}")
+                scale = P_rms.unsqueeze(-1).sqrt()
+                Q *= scale
+                logging.info(f"Q rms = {(Q**2).mean().sqrt()}, abs-rms = {Q.abs().mean()}")
+
+                # no iterative stuff, just use sqrt(P_proj) as scale on Q. If this is False, we need to
+                # call self._update_param_scales(...) from the calling function.
+                if True:
+                    # debug output
+                    step = state["step"]
+                    scale_min, scale_max, scale_mean = scale.min().item(), scale.max().item(), scale.mean().item()
+                    logging.info(f"step={step}, dim={dim}, size={size}, scale min,max,mean={scale_min,scale_max,scale_mean}")

             if True:
+                # debug output
                 this_P_proj_unsmoothed = _diag(torch.matmul(U_prod, torch.matmul(P_prime_unsmoothed, U_prod.transpose(2, 3))))
                 this_P_proj_unsmoothed = this_P_proj_unsmoothed.clone().reshape(batch_size, size)
                 skip = 10 if P_proj[dim].shape[-1] > 40 else 1
                 logging.info(f"dim={dim}, diag of P_proj is: {P_proj[dim][0,::skip]}, diag of unsmoothed P_proj is: {this_P_proj_unsmoothed[0,::skip]}")

+        # P_proj won't be needed if simple_update == True.
         return P_proj
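
Note: the following is a minimal, self-contained sketch (not code from optim.py) of what the new simple_update branch computes: normalize the diagonal of the projected covariance to have mean 1.0 over everything except the batch dimension, then scale Q by the square root of that normalized diagonal. The helper name simple_scale_update, the plain torch.mean in place of the repo's _mean helper, and the concrete tensor shapes are illustrative assumptions only.

# Sketch of the "simple_update" scaling, under assumed shapes:
#   this_P_proj: (batch_size, num_blocks, block_size)  -- diagonal of projected covariance
#   Q:           (batch_size, num_blocks, block_size, block_size)
import torch

def simple_scale_update(this_P_proj: torch.Tensor, Q: torch.Tensor) -> torch.Tensor:
    # Mean over all dims except dim 0 (the batch dim), keeping dims for broadcasting;
    # this stands in for _mean(..., exclude_dims=[0], keepdim=True) in optim.py.
    mean = this_P_proj.mean(dim=(1, 2), keepdim=True)
    P_rms = this_P_proj / mean            # per batch element, mean is now 1.0
    scale = P_rms.unsqueeze(-1).sqrt()    # (batch_size, num_blocks, block_size, 1)
    return Q * scale                      # broadcasts over the last dim of Q

if __name__ == "__main__":
    batch_size, num_blocks, block_size = 2, 3, 4
    P = torch.rand(batch_size, num_blocks, block_size) + 0.1
    Q = torch.randn(batch_size, num_blocks, block_size, block_size)
    print(simple_scale_update(P, Q).shape)  # torch.Size([2, 3, 4, 4])

Because P_rms is normalized to mean 1.0 before the square root is taken, the scaling redistributes magnitude across the rows of Q without changing its overall size on average, which appears to be the Frobenius-norm-preserving property the in-diff comment refers to.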