mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-12 10:34:19 +00:00
Some drafts of rebalancing code in optim.py
This commit is contained in:
parent
bc5c782294
commit
9c9bf4f1e3
@ -272,6 +272,59 @@ class Cain(Optimizer):
|
|||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
@torch.no_grad()
def rebalance(self):
    """
    Rebalance the norms of an averaged-over-time model toward those of the
    final model, separately for "more important" and "less important"
    directions in parameter space.

    This function uses the optimizer for a non-optimization-related purpose.

    The intended use-case is: you compute an averaged-over-time model, e.g.
    by calling average_checkpoints() or
    average_checkpoints_with_averaged_model(); and then you want to adjust
    the parameter norms to be the same as the final model (because the norms
    will tend to be reduced by the parameter averaging); and you want to do
    this separately for the "more important" and "less important" directions
    in parameter space.  This is done by using the co-ordinate changes stored
    in this object (which will tend to put the more important directions in
    lower indexes), and then rebalancing in blocks, of, say, 16x16.

    You should call this function after attaching the optimizer to the
    averaged-over-time model; and you should attach the parameters of the
    "final" model as the .grad of each parameter.

    This function will set the .grad to None.
    """
    for group in self.param_groups:
        for p in group["params"]:
            # Skip params with no attached "final" model, and scalars
            # (a single element has no directions to rebalance).
            if p.grad is None or p.numel() == 1:
                continue

            state = self.state[p]
            step = state["step"]
            # Setting step to 1 prevents _change_coordinates from doing
            # any updates to its stored statistics; restored below.
            state["step"] = 1
            p_averaged = p
            p_final = p.grad

            p_averaged_normalized = self._change_coordinates(p_averaged,
                                                             state,
                                                             forward=True)
            p_final_normalized = self._change_coordinates(p_final, state,
                                                          forward=True)

            # TODO: block-by-block normalization is not implemented yet.
            # The intent (per the docstring) is to scale blocks (e.g. 16x16)
            # of p_averaged_normalized to match the corresponding block
            # norms of p_final_normalized; until then p_final_normalized is
            # unused and this function only round-trips the co-ordinate
            # change.
            # NOTE(review): the second forward call below mirrors the draft;
            # _change_coordinates is presumably stateful across calls —
            # confirm before collapsing it with the one above.
            p_averaged_normalized = self._change_coordinates(p_averaged,
                                                             state,
                                                             forward=True)
            p_averaged = self._change_coordinates(p_averaged_normalized,
                                                  state, forward=False)
            # Write the rebalanced values back into the parameter in place.
            p[:] = p_averaged

            state["step"] = step  # restore the original state

            # The docstring promises the .grad is cleared; the draft
            # omitted this, so callers would have seen stale "final-model"
            # tensors left in .grad.
            p.grad = None
||||||
def _change_coordinates(self,
|
def _change_coordinates(self,
|
||||||
grad: Tensor,
|
grad: Tensor,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user