mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-10 09:34:39 +00:00
More drafts for rebalancing code
This commit is contained in:
parent
9c9bf4f1e3
commit
03e07e80ce
@ -20,6 +20,7 @@ from lhotse.utils import fix_random_seed
|
|||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
from icefall import checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -750,6 +751,28 @@ class Eden(LRScheduler):
|
|||||||
return [x * factor for x in self.base_lrs]
|
return [x * factor for x in self.base_lrs]
|
||||||
|
|
||||||
|
|
||||||
|
def average_checkpoints_with_averaged_model_rebalancing(model: torch.nn.Module,
|
||||||
|
filename_start: str,
|
||||||
|
filename_end: str,
|
||||||
|
device: torch.device):
|
||||||
|
"""
|
||||||
|
This is a version of average_checkpoints_with_averaged_model() from icefall/checkpoint.py,
|
||||||
|
that does `rebalancing`, see function Cain.rebalance() for more explanation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: the model with arbitrary parameters, to be overwritten, but on device `device`.
|
||||||
|
filename_start: the filename for the model at the start of the average period, e.g.
|
||||||
|
'foo/epoch-10.pt'
|
||||||
|
filename_end: the filename for the model at the end of the average period, e.g.
|
||||||
|
'foo/epoch-15.pt'
|
||||||
|
"""
|
||||||
|
|
||||||
|
optimizer = Cain(model.parameters())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _test_eden():
|
def _test_eden():
|
||||||
m = torch.nn.Linear(100, 100)
|
m = torch.nn.Linear(100, 100)
|
||||||
optim = Cain(m.parameters(), lr=0.003)
|
optim = Cain(m.parameters(), lr=0.003)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user