Some drafts of rebalancing code in optim.py

Daniel Povey 2022-06-01 11:34:19 +08:00
parent bc5c782294
commit 9c9bf4f1e3


@@ -272,6 +272,59 @@ class Cain(Optimizer):
        return loss

    @torch.no_grad()
    def rebalance(self):
"""
This function uses the optimizer for a non-optimization-related purpose.
The intended use-case is:
you compute an averaged-over-time model, e.g. by calling average_checkpoints()
or average_checkpoints_with_averaged_model(); and then you want to adjust
the parameter norms to be the same as the final model (because the norms
will tend to be reduced by the parameter averaging); and you want to do this
separately for the "more important" and "less important" directions in
parameter space. This is done by using the co-ordinate changes stored in
this object (which will tend to put the more important directions in lower
indexes), and then rebalancing in blocks, of, say, 16x16.
You should call this function after attaching the optimizer to the
averaged-over-time model; and you should attach the parameters of the "final"
model as the .grad of each parameter.
This function will set the .grad to None.
"""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None or p.numel() == 1:
                    continue
                state = self.state[p]
                step = state["step"]
                state["step"] = 1  # this prevents it from doing any updates;
                                   # we'll restore it later.
                p_averaged = p
                p_final = p.grad
                p_averaged_normalized = self._change_coordinates(p_averaged,
                                                                 state,
                                                                 forward=True)
                p_final_normalized = self._change_coordinates(p_final, state,
                                                              forward=True)
                # now do block-by-block normalization of p_averaged_normalized,
                # matching the block norms of p_final_normalized (this step is
                # still a stub in this draft), then map back to the original
                # co-ordinates.
                p_averaged = self._change_coordinates(p_averaged_normalized,
                                                      state, forward=False)
                p[:] = p_averaged
                p.grad = None  # as documented above
                state["step"] = step  # restore the original state

    def _change_coordinates(self,
                            grad: Tensor,
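
Not shown in the diff: how rebalance() is meant to be driven. Based only on the docstring above, the calling code might look roughly like the sketch below; the helper name rebalance_averaged_model and the assumption that the optimizer's state_dict (with the learned co-ordinate changes) has already been loaded are mine, not part of this commit.

    def rebalance_averaged_model(avg_model, final_model, optim) -> None:
        # Attach the final model's parameters as the .grad of the averaged
        # model's parameters, as the rebalance() docstring asks for ...
        for p_avg, p_final in zip(avg_model.parameters(),
                                  final_model.parameters()):
            p_avg.grad = p_final.detach().clone().to(dtype=p_avg.dtype)
        # ... then rebalance; this also sets each .grad back to None.
        optim.rebalance()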
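
The block-by-block normalization itself is left as a stub in this draft. One plausible reading of the docstring ("adjust the parameter norms to be the same as the final model ... rebalancing in blocks of, say, 16x16") is the sketch below, which scales each 16x16 block of the averaged tensor, in the changed co-ordinates, so that its norm matches the corresponding block of the final tensor. The helper name, the eps guard and the restriction to 2-D tensors are assumptions.

    from torch import Tensor

    # Hypothetical helper, not part of the commit: block-wise norm matching in
    # the changed co-ordinates, for the 2-D case.  It would slot in at the
    # "now do block-by-block normalization" comment, e.g.
    #   p_averaged_normalized = _rebalance_blocks(p_averaged_normalized,
    #                                             p_final_normalized)
    def _rebalance_blocks(p_averaged_normalized: Tensor,
                          p_final_normalized: Tensor,
                          block_size: int = 16,
                          eps: float = 1.0e-20) -> Tensor:
        out = p_averaged_normalized.clone()
        num_rows, num_cols = out.shape
        for r in range(0, num_rows, block_size):
            for c in range(0, num_cols, block_size):
                avg_block = out[r:r + block_size, c:c + block_size]  # view into out
                final_block = p_final_normalized[r:r + block_size, c:c + block_size]
                # Scale this block of the averaged tensor so its Frobenius norm
                # matches that of the corresponding block of the final tensor.
                avg_block.mul_(final_block.norm() / (avg_block.norm() + eps))
        return out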