More drafts for rebalancing code

This commit is contained in:
Daniel Povey 2022-06-01 13:58:42 +08:00
parent 9c9bf4f1e3
commit 03e07e80ce

View File

@ -20,6 +20,7 @@ from lhotse.utils import fix_random_seed
import torch
from torch import Tensor
from torch.optim import Optimizer
from icefall import checkpoint
@ -750,6 +751,28 @@ class Eden(LRScheduler):
return [x * factor for x in self.base_lrs]
def average_checkpoints_with_averaged_model_rebalancing(model: torch.nn.Module,
filename_start: str,
filename_end: str,
device: torch.device):
"""
This is a version of average_checkpoints_with_averaged_model() from icefall/checkpoint.py,
that does `rebalancing`, see function Cain.rebalance() for more explanation.
Args:
model: the model with arbitrary parameters, to be overwritten, but on device `device`.
filename_start: the filename for the model at the start of the average period, e.g.
'foo/epoch-10.pt'
filename_end: the filename for the model at the end of the average period, e.g.
'foo/epoch-15.pt'
"""
optimizer = Cain(model.parameters())
def _test_eden():
m = torch.nn.Linear(100, 100)
optim = Cain(m.parameters(), lr=0.003)