Remove some rebalancing code that I am now not going to use.

Daniel Povey 2022-06-01 22:19:28 +08:00
parent 0c73664aef
commit b1f6797af1


@@ -20,7 +20,6 @@ from lhotse.utils import fix_random_seed
import torch
from torch import Tensor
from torch.optim import Optimizer
from icefall import checkpoint
@@ -273,59 +272,6 @@ class Cain(Optimizer):
        return loss

    @torch.no_grad()
    def rebalance(self):
"""
This function uses the optimizer for a non-optimization-related purpose.
The intended use-case is:
you compute an averaged-over-time model, e.g. by calling average_checkpoints()
or average_checkpoints_with_averaged_model(); and then you want to adjust
the parameter norms to be the same as the final model (because the norms
will tend to be reduced by the parameter averaging); and you want to do this
separately for the "more important" and "less important" directions in
parameter space. This is done by using the co-ordinate changes stored in
this object (which will tend to put the more important directions in lower
indexes), and then rebalancing in blocks, of, say, 16x16.
You should call this function after attaching the optimizer to the
averaged-over-time model; and you should attach the parameters of the "final"
model as the .grad of each parameter.
This function will set the .grad to None.
"""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None or p.numel() == 1:
                    continue
                state = self.state[p]
                step = state["step"]
                state["step"] = 1  # this prevents it from doing any updates;
                                   # we'll restore it later.
                p_averaged = p
                p_final = p.grad
                p_averaged_normalized = self._change_coordinates(p_averaged,
                                                                 state,
                                                                 forward=True)
                p_final_normalized = self._change_coordinates(p_final, state,
                                                              forward=True)
                # now do block-by-block normalization (TODO- not done yet)
                p_averaged_normalized = self._change_coordinates(p_averaged,
                                                                 state,
                                                                 forward=True)
                p_averaged = self._change_coordinates(p_averaged_normalized,
                                                      state, forward=False)
                p[:] = p_averaged
                state["step"] = step  # restore the original state
    def _change_coordinates(self,
                            grad: Tensor,
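The block-by-block normalization is left as a TODO in the removed code. Below is a minimal sketch of what that step might look like, assuming each parameter tensor is flattened and rescaled in fixed-size blocks so that each block of the averaged tensor takes on the norm of the corresponding block of the final tensor; the helper name _rebalance_blocks, the block size of 16, and the zero-padding of the last block are assumptions, not part of the removed code.

import torch

def _rebalance_blocks(p_avg: torch.Tensor,
                      p_final: torch.Tensor,
                      block_size: int = 16) -> torch.Tensor:
    # Flatten both tensors and zero-pad to a whole number of blocks.
    n = p_avg.numel()
    num_blocks = (n + block_size - 1) // block_size
    pad = num_blocks * block_size - n
    a = torch.nn.functional.pad(p_avg.reshape(-1), (0, pad))
    f = torch.nn.functional.pad(p_final.reshape(-1), (0, pad))
    a = a.reshape(num_blocks, block_size)
    f = f.reshape(num_blocks, block_size)
    # Scale each block of the averaged tensor so that its norm matches
    # the norm of the corresponding block of the final tensor.
    eps = 1.0e-20
    scale = ((f.norm(dim=1, keepdim=True) + eps) /
             (a.norm(dim=1, keepdim=True) + eps))
    a = a * scale
    # Strip the padding and restore the original shape.
    return a.reshape(-1)[:n].reshape(p_avg.shape)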
@@ -751,40 +697,6 @@ class Eden(LRScheduler):
        return [x * factor for x in self.base_lrs]
def average_checkpoints_with_averaged_model_rebalancing(
        model: torch.nn.Module,
        filename_start: str,
        filename_end: str,
        device: torch.device):
    """
    This is a version of average_checkpoints_with_averaged_model() from
    icefall/checkpoint.py that also does `rebalancing`; see the function
    Cain.rebalance() for more explanation.

    Args:
       model: the model with arbitrary parameters, to be overwritten, but on
          device `device`.
       filename_start: the filename for the model at the start of the average
          period, e.g. 'foo/epoch-10.pt'
       filename_end: the filename for the model at the end of the average
          period, e.g. 'foo/epoch-15.pt'
       device: the device on which to load the checkpoints and do the
          computation.
    """
    optimizer = Cain(model.parameters())
    averaged_state_dict = checkpoint.average_checkpoints_with_averaged_model(
        filename_start, filename_end, device, decompose=True)
    model.load_state_dict(averaged_state_dict)
    c = torch.load(filename_end, map_location=device)
    final_state_dict = c["model"]
    # attach the final model's parameters as grads, so they will be visible
    # to the optimizer object when we call rebalance()
    for name, param in model.named_parameters():
        param.grad = final_state_dict[name]
    optimizer.rebalance()

def _test_eden():
    m = torch.nn.Linear(100, 100)
    optim = Cain(m.parameters(), lr=0.003)
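For context, a minimal sketch of how the removed helper would have been invoked; the checkpoint paths follow the examples in its docstring, and the Linear model is a stand-in for whatever network those checkpoints were actually saved from.

import torch

device = torch.device("cuda", 0)
# Stand-in model; in practice this would be the network that the checkpoints
# 'foo/epoch-10.pt' and 'foo/epoch-15.pt' were saved from.
model = torch.nn.Linear(100, 100).to(device)

average_checkpoints_with_averaged_model_rebalancing(
    model,
    filename_start="foo/epoch-10.pt",
    filename_end="foo/epoch-15.pt",
    device=device)
# model now holds time-averaged parameters whose norms have been rebalanced
# toward those of the final checkpoint.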