mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-09 09:04:19 +00:00

Remove some rebalancing code that I am now not going to use.

This commit is contained in:
parent 0c73664aef
commit b1f6797af1
@@ -20,7 +20,6 @@ from lhotse.utils import fix_random_seed
 import torch
 from torch import Tensor
 from torch.optim import Optimizer
-from icefall import checkpoint
 
 
 
@@ -273,59 +272,6 @@ class Cain(Optimizer):
 
         return loss
 
-    @torch.no_grad()
-    def rebalance(self):
-        """
-        This function uses the optimizer for a non-optimization-related purpose.
-        The intended use-case is: you compute an averaged-over-time model,
-        e.g. by calling average_checkpoints() or
-        average_checkpoints_with_averaged_model(), and then you want to adjust
-        the parameter norms to be the same as the final model's (because the
-        norms will tend to be reduced by the parameter averaging), and you
-        want to do this separately for the "more important" and "less
-        important" directions in parameter space.  This is done by using the
-        co-ordinate changes stored in this object (which will tend to put the
-        more important directions in lower indexes), and then rebalancing in
-        blocks of, say, 16x16.
-
-        You should call this function after attaching the optimizer to the
-        averaged-over-time model; and you should attach the parameters of the
-        "final" model as the .grad of each parameter.
-
-        This function will set the .grad to None.
-        """
-        for group in self.param_groups:
-            for p in group["params"]:
-                if p.grad is None or p.numel() == 1:
-                    continue
-                state = self.state[p]
-                step = state["step"]
-                state["step"] = 1  # this prevents it from doing any updates;
-                                   # we'll restore it later
-                p_averaged = p
-                p_final = p.grad
-                p_averaged_normalized = self._change_coordinates(p_averaged,
-                                                                 state,
-                                                                 forward=True)
-                p_final_normalized = self._change_coordinates(p_final, state,
-                                                              forward=True)
-
-                # now do block-by-block normalization (TODO: not done yet)
-
-                p_averaged_normalized = self._change_coordinates(p_averaged,
-                                                                 state,
-                                                                 forward=True)
-                p_averaged = self._change_coordinates(p_averaged_normalized,
-                                                      state, forward=False)
-                p[:] = p_averaged
-
-                state["step"] = step  # restore the original state
-
-
     def _change_coordinates(self,
                             grad: Tensor,
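Note that the core of the removed rebalance() was never written: the block-by-block normalization exists only as a TODO comment between the two _change_coordinates() calls. Going by the docstring (scale each block of the averaged model so its norm matches the corresponding block of the final model), a minimal sketch of what that step could have looked like follows. This is not icefall code; the helper name rebalance_block_norms, the 1-D blocking, and the padding are all illustrative assumptions.

import torch

def rebalance_block_norms(p_averaged: torch.Tensor,
                          p_final: torch.Tensor,
                          block_size: int = 16) -> torch.Tensor:
    # Illustrative sketch, not from icefall: scale each block of the
    # averaged tensor so its norm matches the corresponding block of
    # the final tensor.  Assumes both tensors have the same shape and
    # are already in the optimizer's canonical co-ordinates.
    a = p_averaged.reshape(-1)
    f = p_final.reshape(-1)
    pad = (-a.numel()) % block_size  # pad to a multiple of block_size
    if pad > 0:
        a = torch.nn.functional.pad(a, (0, pad))
        f = torch.nn.functional.pad(f, (0, pad))
    a = a.reshape(-1, block_size)
    f = f.reshape(-1, block_size)
    # Per-block scale factor; the epsilon guards against zero blocks.
    scale = f.norm(dim=1, keepdim=True) / (a.norm(dim=1, keepdim=True) + 1e-20)
    a = (a * scale).reshape(-1)
    if pad > 0:
        a = a[:-pad]  # drop the padding again
    return a.reshape(p_averaged.shape)

In the removed method this would presumably slot in at the TODO, with its result passed through _change_coordinates(..., forward=False) in place of the re-computed p_averaged_normalized.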
@@ -751,40 +697,6 @@ class Eden(LRScheduler):
         return [x * factor for x in self.base_lrs]
 
 
-def average_checkpoints_with_averaged_model_rebalancing(model: torch.nn.Module,
-                                                        filename_start: str,
-                                                        filename_end: str,
-                                                        device: torch.device):
-    """
-    This is a version of average_checkpoints_with_averaged_model() from
-    icefall/checkpoint.py that does "rebalancing"; see Cain.rebalance() for
-    more explanation.
-
-    Args:
-      model: the model with arbitrary parameters, to be overwritten, but on
-        device `device`.
-      filename_start: the filename for the model at the start of the averaging
-        period, e.g. 'foo/epoch-10.pt'
-      filename_end: the filename for the model at the end of the averaging
-        period, e.g. 'foo/epoch-15.pt'
-    """
-    optimizer = Cain(model.parameters())
-
-    averaged_state_dict = checkpoint.average_checkpoints_with_averaged_model(
-        filename_start, filename_end, device, decompose=True)
-    model.load_state_dict(averaged_state_dict)
-
-    c = torch.load(filename_end, map_location=device)
-    final_state_dict = c["model"]
-
-    # attach the final model's parameters as grads, so they will be visible
-    # to the optimizer object when we call rebalance()
-    for name, param in model.named_parameters():
-        param.grad = final_state_dict[name]
-
-
 def _test_eden():
     m = torch.nn.Linear(100, 100)
     optim = Cain(m.parameters(), lr=0.003)
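As removed, the helper builds a Cain optimizer and attaches the final model's parameters as .grad, but the hunk never reaches an optimizer.rebalance() call, which fits the commit message: the code was abandoned unfinished. The .grad trick itself is the interesting part, since it is what lets the optimizer see both the averaged and the final value of every parameter. A minimal, self-contained sketch of just that trick (the two Linear modules are stand-ins for the averaged and final models; nothing here is icefall code):

import torch

averaged = torch.nn.Linear(4, 4)  # stand-in for the averaged-over-time model
final = torch.nn.Linear(4, 4)     # stand-in for the final model

# Attach the final model's parameters as the .grad of the averaged
# model's parameters, as the removed helper did; an optimizer built
# on averaged.parameters() can then read both values per parameter.
final_state = final.state_dict()
for name, param in averaged.named_parameters():
    param.grad = final_state[name]

# Each parameter now carries its averaged value in .data and the
# final model's value in .grad.
for name, param in averaged.named_parameters():
    assert torch.equal(param.grad, final_state[name])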