From 9c9bf4f1e309096eeee466b1a9f5566040cb5ff2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 1 Jun 2022 11:34:19 +0800 Subject: [PATCH] Some drafts of rebalancing code in optim.py --- .../ASR/pruned_transducer_stateless7/optim.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index d7e0193dc..8fca662e2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -272,6 +272,59 @@ class Cain(Optimizer): return loss + @torch.no_grad() + def rebalance(self): + """ + + This function uses the optimizer for a non-optimization-related purpose. + The intended use-case is: + you compute an averaged-over-time model, e.g. by calling average_checkpoints() + or average_checkpoints_with_averaged_model(); and then you want to adjust + the parameter norms to be the same as the final model (because the norms + will tend to be reduced by the parameter averaging); and you want to do this + separately for the "more important" and "less important" directions in + parameter space. This is done by using the co-ordinate changes stored in + this object (which will tend to put the more important directions in lower + indexes), and then rebalancing in blocks, of, say, 16x16. + + You should call this function after attaching the optimizer to the + averaged-over-time model; and you should attach the parameters of the "final" + model as the .grad of each parameter. + + This function will set the .grad to None. + """ + for group in self.param_groups: + + for p in group["params"]: + if p.grad is None or p.numel() == 1: + continue + + state = self.state[p] + step = state["step"] + state["step"] = 1 # this prevents it from doing any updates, + # we'll restore it later + p_averaged = p + p_final = p.grad + + p_averaged_normalized = self._change_coordinates(p_averaged, + state, + forward=True) + p_final_normalized = self._change_coordinates(p_final, state, + forward=True) + + # now do block-by-block normalization + + + p_averaged_normalized = self._change_coordinates(p_averaged, + state, + forward=True) + p_averaged = self._change_coordinates(p_averaged_normalized, + state, forward=False) + p[:] = p_averaged + + state["step"] = step # restore the original state + + def _change_coordinates(self, grad: Tensor,