From 41df04577374619bf8d6f281e53b0c8ca5664daf Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 11 Jul 2022 17:14:12 -0700
Subject: [PATCH] Simplify formula, getting rid of scalar_exp_avg_sq

---
 .../ASR/pruned_transducer_stateless7/optim.py | 42 +++++++++----------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 807a86750..e5c597982 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -257,9 +257,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         state["param_rms"] = param_rms
 
         state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
-        # scalar exp_avg_sq.  not the same as scale_exp_avg_sq, this is used to
-        # determine the magnitude of the regular step
-        state["scalar_exp_avg_sq"] = torch.zeros_like(param_rms)
         state["scale_grads"] = torch.zeros(size_update_period,
                                            *param_rms.shape, **kwargs)
 
@@ -280,7 +277,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         state["trivial_update"] = trivial_update  # so we know whether to zero exp_avg_sq stats.
 
         if trivial_update:
-            logging.info(f"Shape={p.shape}, trivial update.")  # TODO: remove
             return
 
         # "zero_step" being a member of state is the sign that this parameter has
@@ -511,7 +507,6 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         Zero the exp_avg_sq stats, and set state["zero_step"] to state["step"]
         """
         state["exp_avg_sq"].zero_()
-        state["scalar_exp_avg_sq"].zero_()
         state["zero_step"] = state["step"]
 
     def _update_lrs(self,
@@ -599,7 +594,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
             # (where the stats permit).
             scale = (S.clamp(min=eps) / cur_param_var.clamp(min=eps)).sqrt()
 
-            logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::10]}, cur_param_var={cur_param_var[0].flatten()[::10]}, S={S[0].flatten()[::10]}")
+            if random.random() < 0.01:
+                logging.info(f"shape={p.shape}, dim={dim}, scale={scale[0].flatten()[::10]}, cur_param_var={cur_param_var[0].flatten()[::10]}, S={S[0].flatten()[::10]}")
 
             # scale shape: (batch_size, 1, size, 1, 1)
             cur_p *= scale
@@ -654,6 +650,13 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # Apply the scales in `cur_scales` to Q for each dim; this reflects the
         # parameter rms values in the parameter-diagonalized space, that we have
         # estimated in the loop above.
+        #
+        # We normalize the scales in such a way that the Frobenius norm
+        # after projecting (grad / denom) with Q should be unchanged, i.e. the
+        # same as (grad / denom), which is equivalent to having rms=1.0 due
+        # to how denom is constructed.  This simplifies the normalization of the
+        # overall scale of the parameter change: we just have to multiply by the
+        # learning rate and param_rms.
         for dim in range(1, ndim):
             if cur_scales[dim] is not None:
                 size = p.shape[dim]
@@ -661,6 +664,16 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
                 (batch_size, num_blocks, block_size, block_size) = Q.shape
                 scale = cur_scales[dim].reshape(batch_size, num_blocks, block_size, 1)
 
+                # The following normalization step will ensure that the
+                # Frobenius norm is unchanged by applying this scale: at least,
+                # assuming "grad / denom" gives uncorrelated outputs so that
+                # they will have equal variances after projecting to the space
+                # where the parameter var is diagonalized... this is *roughly*
+                # true because the gradients, at the point where we compute
+                # "grad / denom", should be decorrelated at least considering
+                # individual tensor dims.
+                scale /= _mean(scale**2, exclude_dims=[0], keepdim=True).sqrt()
+
                 # Q is indexed [batch_index, block_index, diagonalized_coordinate, canonical_coordinate],
                 # want to multiply on the diagonalized co-ordinate.
                 # else: Q is indexed [batch_index, canonical_coordinate].
@@ -825,21 +838,8 @@ param_rms_smooth1: Smoothing proportion for parameter matrix, if assumed rank of
         # and project back..
         grad = self._project(grad, state, forward=False)
 
-        scalar_exp_avg_sq = state["scalar_exp_avg_sq"]
-        scalar_exp_avg_sq.mul_(beta2).add_(_mean(grad**2,
-                                                 exclude_dims=[0],
-                                                 keepdim=True),
-                                           alpha=1-beta2)
-
-        if bias_correction2 < 0.99:
-            # scalar_exp_avg_sq is also zeroed at step "zero_step", so
-            # use the same bias_correction2.
-            scalar_exp_avg_sq = scalar_exp_avg_sq * (1.0 / bias_correction2)
-
-        denom = scalar_exp_avg_sq.sqrt() + eps  # scalar, or [batch_size,1,1..]
-
-        alpha = state["param_rms"] * (1-beta1) * -lr / denom
+        alpha = -lr * (1-beta1) * state["param_rms"]
 
         delta = state["delta"]
         delta.add_(grad * alpha)
 
@@ -1310,7 +1310,7 @@ class Cain(Optimizer):
                 bias_correction2 = 1 - beta2 ** step
 
                 denom = (exp_avg_sq.sqrt()).add_(eps)
                 this_delta = grad / denom
-                alpha = -lr*(1-beta1)*(bias_correction2 ** 0.5)
+                alpha = -lr * (1-beta1) * (bias_correction2 ** 0.5)
                 delta.add_(this_delta, alpha=alpha)
                 p.add_(delta)
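
For intuition about the invariant described in the comments added at @@ -654,6 +650,13 @@ and @@ -661,6 +664,16 @@: dividing `scale` by its root-mean-square makes the Frobenius norm of (grad / denom) roughly invariant under the scaling, because ||x * s||^2 = sum_i s_i^2 * ||x_i||^2, so forcing mean(s^2) = 1 preserves the expected squared norm whenever the slices x_i have equal expected norms (the decorrelation assumption the comment states).  Below is a minimal numerical sketch of this check, separate from the patch, with made-up shapes and a plain-torch stand-in for the optimizer's _mean() helper and cur_scales[dim]:

    import torch

    torch.manual_seed(0)

    # Stand-ins: x plays the role of (grad / denom) after projection, and
    # scale the role of cur_scales[dim]; the shapes are arbitrary.
    batch_size, size, other = 4, 32, 64
    x = torch.randn(batch_size, size, other)
    scale = torch.rand(batch_size, size, 1) + 0.5

    # Same normalization as the patch's
    #   scale /= _mean(scale**2, exclude_dims=[0], keepdim=True).sqrt()
    # written with plain torch.mean over all dims except the batch dim.
    scale = scale / scale.pow(2).mean(dim=(1, 2), keepdim=True).sqrt()

    before = x.pow(2).sum(dim=(1, 2)).sqrt()           # Frobenius norm of x
    after = (x * scale).pow(2).sum(dim=(1, 2)).sqrt()  # norm after scaling

    # `after` stays within a few percent of `before`: mean(scale**2) == 1 per
    # batch element, and the slices of x have nearly equal norms.
    print(before)
    print(after)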
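
The deletion at @@ -825,21 +838,8 @@ leans on exactly that property: once the projected-back grad has rms close to 1.0, the old Adam-style scalar_exp_avg_sq denominator is itself close to 1.0, so the step size reduces to -lr * (1-beta1) * param_rms.  A before/after sketch with placeholder values; the "old" branch paraphrases the deleted code, using an instantaneous rms in place of the deleted exponential moving average and its bias correction:

    import torch

    torch.manual_seed(0)
    lr, beta1 = 0.05, 0.9
    param_rms = torch.tensor(0.1)
    grad = torch.randn(512)  # projected gradient; rms ~= 1.0 by construction

    # Old formula (paraphrased): divide by a running rms of grad.
    denom = grad.pow(2).mean().sqrt() + 1e-8  # ~= 1.0 when grad has rms ~= 1.0
    alpha_old = param_rms * (1 - beta1) * -lr / denom

    # New formula from this patch: with denom ~= 1.0 guaranteed by the scale
    # normalization, the division is redundant.
    alpha_new = -lr * (1 - beta1) * param_rms

    print(alpha_old.item(), alpha_new.item())  # nearly identical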