From 2c4bdd0ad03a920682e4aab22ff240bd4f096032 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 23 Jul 2022 14:49:58 +0800
Subject: [PATCH] Add _update_param_scales_simple(), add documentation

---
 .../ASR/pruned_transducer_stateless7/optim.py | 55 +++++++++++++++----
 1 file changed, 44 insertions(+), 11 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 2875c795a..f6539e1f1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -555,15 +555,42 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         return (step >= zero_step + cur_update_period)
 
+    def _update_param_scales_simple(self,
+                                    group: dict,
+                                    p: Tensor,
+                                    state: dict,
+                                    P_proj: List[Optional[Tensor]]) -> None:
+        for dim in range(1, p.ndim):
+            size = p.shape[dim]
+            try:
+                Q = state[f"Q_{dim}"]
+            except KeyError:
+                assert size == 1 or size == numel, size
+                continue  # e.g. size == 1 or size == numel
+
+            (batch_size, num_blocks, block_size, block_size) = Q.shape
+            this_P_proj = P_proj[dim].reshape(batch_size, num_blocks, block_size, 1)
+
+            # The following normalization step will ensure the Frobenius
+            # norm is unchanged, from applying this scale: at least,
+            # assuming "grad / denom" gives uncorrelated outputs so that
+            # they will have equal variances after projecting to the space
+            # where the parameter var is diagonalized... this is *roughly*
+            # true because the gradients at the point where we compute "grad
+            # / denom" should be decorrelated at least considering
+            # individual tensor dims
+            this_P_proj /= _mean(this_P_proj, exclude_dims=[0], keepdim=True)
+
+            Q *= this_P_proj.sqrt()
+
     def _update_param_scales(self,
                              group: dict,
                              p: Tensor,
                              state: dict,
                              P_proj: List[Optional[Tensor]]) -> None:
         """
-        Computes learning-rate matrices Q for each dim of this tensor: only the part that depends
-        on the parameter covariance, we will later add a rotation that depends on the gradient
-        covariance.
+        Modifies the scales on the rows of the learning-rate matrices Q for each dim of this tensor,
+        to take into account the estimated parameter covariance.
 
         Args:
            group: dict to look up configuration values
@@ -636,13 +663,19 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             S = smoothed_param_var.reshape(cur_param_var.shape)  # (batch_size, 1, size, 1, 1) if dim==2
 
             # OK, cur_param_var would have the values as S if the variance stats
-            # param_cov_{dim} were accumulated from this exact parameter matrix,
-            # but actually they contain older versions of the parameter
-            # covariance so they will, in general, be less extreme ("whiter
-            # spectrum").  We scale p so that it matches the accumulated stats,
-            # the idea is to ensure it doesn't have any too-small eigenvalues
-            # (where the stats permit).
-
+            # P_proj[dim] were accumulated from this exact parameter matrix and
+            # not smoothed, but actually they contain older versions of the
+            # parameter covariance and they have been smoothed, so they will, in
+            # general, be less extreme ("whiter spectrum").  We scale p so that
+            # it matches the estimated variance P_proj[dim]; the idea is to ensure it doesn't
+            # have too-extreme eigenvalues (where the stats permit).
+            # Actually we could just use P_proj[dim].sqrt(), suitably scaled,
+            # as the scales on the rows of Q (see _update_param_scales_simple() which does
+            # exactly this), but there is a problem of "counting things twice"
+            # which is easiest to understand for a 2-dimensional tensor, where the
+            # singular values show up identically in the covariance over either axis.
+            # The estimation procedure in this function avoids the "counting things twice"
+            # problem, at the expense of quite a bit of extra complexity.
             scale = (S / cur_param_var.clamp(min=eps)).sqrt()
 
             if True:
@@ -650,7 +683,7 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 cur_tmp = cur_param_var.reshape(batch_size, size)
                 scale_tmp = scale.reshape(batch_size, size)
                 skip = 10 if size > 40 else 1
-                logging.info(f"cur_param_var={cur_tmp[0][::skip]}, S={S_tmp[0][::skip]}, scale={scale_tmp[0][::skip]}")
+                logging.info(f"dim={dim}/{ndim}, cur_param_var={cur_tmp[0][::skip]}, S={S_tmp[0][::skip]}, scale={scale_tmp[0][::skip]}")
 
             if random.random() < 0.01:
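
A brief illustration of the normalization argument in the _update_param_scales_simple()
comment above: dividing the projected variances by their mean makes the mean of the
squared scales equal to 1, so scaling a quantity whose entries are uncorrelated with
equal variances leaves its expected squared Frobenius norm unchanged.  This is only a
toy sketch (PyTorch assumed; `var`, `scale` and `x` are made-up stand-ins for
P_proj[dim], the row scales applied to Q, and "grad / denom" -- they are not names
from optim.py):

    import torch

    torch.manual_seed(0)

    # Made-up per-row variance estimates for one dim (stand-in for P_proj[dim]).
    var = torch.rand(8) + 0.1

    # Normalize so the mean is 1 (the analogue of `this_P_proj /= _mean(...)`);
    # the scales applied to the rows of Q are the square roots of these values.
    scale = (var / var.mean()).sqrt()
    print(scale.pow(2).mean().item())  # == 1.0 up to rounding

    # For entries that are uncorrelated with equal variance (the "grad / denom"
    # assumption), the expected squared Frobenius norm is unchanged by the scaling.
    x = torch.randn(100000, 8)
    print(x.pow(2).mean().item())            # ~1.0
    print((x * scale).pow(2).mean().item())  # also ~1.0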
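
And a sketch of the "counting things twice" problem referred to in the new comment in
_update_param_scales(): for a 2-dimensional parameter W = U diag(s) V^T, the
covariance-style stats over the rows (W W^T) and over the columns (W^T W) share the
same nonzero eigenvalues s**2, so per-dim scales taken independently from each dim's
covariance each reflect the same singular values.  This is generic linear algebra
rather than code from the patch; PyTorch is assumed and W is a toy matrix:

    import torch

    torch.manual_seed(0)
    W = torch.randn(4, 6)                       # toy 2-D parameter matrix

    row_cov = W @ W.t()                         # stats over dim 0, shape (4, 4)
    col_cov = W.t() @ W                         # stats over dim 1, shape (6, 6)

    s2 = torch.linalg.svdvals(W).pow(2)         # squared singular values, descending
    row_eigs = torch.linalg.eigvalsh(row_cov)   # eigenvalues, ascending
    col_eigs = torch.linalg.eigvalsh(col_cov)   # ascending; two of them are ~0

    # Both covariances expose the same nonzero eigenvalues s**2, so scales derived
    # independently for each dim would each "see" the same singular values.
    print(torch.allclose(row_eigs.flip([0]), s2, atol=1e-3))       # True
    print(torch.allclose(col_eigs.flip([0])[:4], s2, atol=1e-3))   # True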