From 03e07e80ce30311c2e109f152a2991482675e142 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 1 Jun 2022 13:58:42 +0800 Subject: [PATCH] More drafts for rebalancing code --- .../ASR/pruned_transducer_stateless7/optim.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 8fca662e2..03adabc3c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -20,6 +20,7 @@ from lhotse.utils import fix_random_seed import torch from torch import Tensor from torch.optim import Optimizer +from icefall import checkpoint @@ -750,6 +751,28 @@ class Eden(LRScheduler): return [x * factor for x in self.base_lrs] +def average_checkpoints_with_averaged_model_rebalancing(model: torch.nn.Module, + filename_start: str, + filename_end: str, + device: torch.device): + """ + This is a version of average_checkpoints_with_averaged_model() from icefall/checkpoint.py, + that does `rebalancing`, see function Cain.rebalance() for more explanation. + + Args: + model: the model with arbitrary parameters, to be overwritten, but on device `device`. + filename_start: the filename for the model at the start of the average period, e.g. + 'foo/epoch-10.pt' + filename_end: the filename for the model at the end of the average period, e.g. + 'foo/epoch-15.pt' + """ + + optimizer = Cain(model.parameters()) + + + + + def _test_eden(): m = torch.nn.Linear(100, 100) optim = Cain(m.parameters(), lr=0.003)