diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 7d17386d5..409425f5a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -20,7 +20,6 @@ from lhotse.utils import fix_random_seed
 import torch
 from torch import Tensor
 from torch.optim import Optimizer
-from icefall import checkpoint
 
 
@@ -273,59 +272,6 @@ class Cain(Optimizer):
 
         return loss
 
-    @torch.no_grad()
-    def rebalance(self):
-        """
-
-        This function uses the optimizer for a non-optimization-related purpose.
-        The intended use-case is:
-        you compute an averaged-over-time model, e.g. by calling average_checkpoints()
-        or average_checkpoints_with_averaged_model(); and then you want to adjust
-        the parameter norms to be the same as the final model (because the norms
-        will tend to be reduced by the parameter averaging); and you want to do this
-        separately for the "more important" and "less important" directions in
-        parameter space.  This is done by using the co-ordinate changes stored in
-        this object (which will tend to put the more important directions in lower
-        indexes), and then rebalancing in blocks, of, say, 16x16.
-
-        You should call this function after attaching the optimizer to the
-        averaged-over-time model; and you should attach the parameters of the "final"
-        model as the .grad of each parameter.
-
-        This function will set the .grad to None.
-        """
-        for group in self.param_groups:
-
-            for p in group["params"]:
-                if p.grad is None or p.numel() == 1:
-                    continue
-
-                state = self.state[p]
-                step = state["step"]
-                state["step"] = 1  # this prevents it from doing any updates,
-                                   # we'll restore it later
-                p_averaged = p
-                p_final = p.grad
-
-                p_averaged_normalized = self._change_coordinates(p_averaged,
-                                                                 state,
-                                                                 forward=True)
-                p_final_normalized = self._change_coordinates(p_final, state,
-                                                              forward=True)
-
-                # now do block-by-block normalization (TODO- not done yet)
-
-
-                p_averaged_normalized = self._change_coordinates(p_averaged,
-                                                                 state,
-                                                                 forward=True)
-                p_averaged = self._change_coordinates(p_averaged_normalized,
-                                                      state, forward=False)
-                p[:] = p_averaged
-
-                state["step"] = step  # restore the original state
-
-
     def _change_coordinates(self,
                             grad: Tensor,
@@ -751,40 +697,6 @@ class Eden(LRScheduler):
         return [x * factor for x in self.base_lrs]
 
 
-def average_checkpoints_with_averaged_model_rebalancing(model: torch.nn.Module,
-                                                        filename_start: str,
-                                                        filename_end: str,
-                                                        device: torch.device):
-    """
-    This is a version of average_checkpoints_with_averaged_model() from icefall/checkpoint.py,
-    that does `rebalancing`, see function Cain.rebalance() for more explanation.
-
-    Args:
-      model: the model with arbitrary parameters, to be overwritten, but on device `device`.
-      filename_start: the filename for the model at the start of the average period, e.g.
-        'foo/epoch-10.pt'
-      filename_end: the filename for the model at the end of the average period, e.g.
-        'foo/epoch-15.pt'
-    """
-
-    optimizer = Cain(model.parameters())
-
-    averaged_state_dict = checkpoint.average_checkpoints_with_averaged_model(
-        filename_start, filename_end, device, decompose=True)
-    model.load_state_dict(averaged_state_dict)
-
-    c = torch.load(filename_end, map_location=device)
-    final_state_dict = c["model"]
-
-    # attach the final model's parameters as grads, so they will be visible
-    # to the optimizer object when we call rebalance()
-    for name, param in model.named_parameters():
-        param.grad = final_state_dict[name]
-
-
-
-
 def _test_eden():
     m = torch.nn.Linear(100, 100)
     optim = Cain(m.parameters(), lr=0.003)
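For context on what the removed rebalance() was aiming at: its docstring describes rescaling the averaged model's parameters so that their norms match the final model's, block by block, in the coordinates produced by _change_coordinates() (where the more important directions tend to land in the lower indexes); the block-by-block normalization itself was still marked TODO. The sketch below illustrates that idea under simplifying assumptions: it uses flat 1-D blocks rather than the 16x16 blocks mentioned in the docstring, and the names rebalance_blocks and block_size are illustrative only, not part of icefall.

# Hypothetical sketch, not icefall code: rescale each block of the averaged
# (coordinate-changed) parameter so its norm matches the corresponding block
# of the final model.
import torch


def rebalance_blocks(p_averaged_normalized: torch.Tensor,
                     p_final_normalized: torch.Tensor,
                     block_size: int = 16) -> torch.Tensor:
    # Work on flattened copies so the sketch handles any tensor shape.
    avg = p_averaged_normalized.flatten()
    fin = p_final_normalized.flatten()
    out = avg.clone()
    for start in range(0, avg.numel(), block_size):
        end = min(start + block_size, avg.numel())
        avg_norm = avg[start:end].norm()
        fin_norm = fin[start:end].norm()
        # Guard against an all-zero block in the averaged parameter.
        if avg_norm > 0:
            out[start:end] = avg[start:end] * (fin_norm / avg_norm)
    return out.reshape(p_averaged_normalized.shape)

A completed rebalance() might then apply such a function between the two _change_coordinates(..., forward=True) calls and the final _change_coordinates(..., forward=False) call, so that the per-block rescaling happens in the coordinates where the more important directions occupy the lower indexes.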