From 8efc512823fb555b8f303f6cdb1ec4c65dc2df72 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 24 Jul 2022 11:52:10 +0800 Subject: [PATCH] Remove some debugging code, found the mismatch --- .../ASR/pruned_transducer_stateless7/optim.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index bd0435db5..b7f010a83 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -956,21 +956,10 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of if simple_update: # normalize the scales in a way that preserves the Frobenius norm of the # projected parameter deltas - P_rms = this_P_proj.clone() - P_rms = P_rms / _mean(P_rms, exclude_dims=[0], keepdim=True) - P_rms_compare = this_P_proj.clone() - P_rms_compare /= _mean(P_rms_compare, exclude_dims=[0], keepdim=True) - - mean_diff = (P_rms - P_rms_compare) - if True: - ratio = mean_diff.abs().sum() / P_rms.abs().sum() - logging.info(f"ratio for division is {ratio}, shapes are {P_rms.shape}, {P_rms_compare.shape}") - if ratio > 1.0e-10: - logging.warn(f"P_rms={P_rms}, P_rms_compare={P_rms_compare}") + P_rms = this_P_proj / _mean(this_P_proj, exclude_dims=[0], keepdim=True) scale = P_rms.unsqueeze(-1).sqrt() Q *= scale - logging.info(f"Q rms = {(Q**2).mean().sqrt()}, abs-rms = {Q.abs().mean()}") - + logging.info(f"Q rms = {(Q**2).mean().sqrt()} abs-rms = {Q.abs().mean()}") # no iterative stuff, just use sqrt(P_proj) as scale on Q. If this is False, we need to # call self._update_param_scales(...) from the calling function. if True: