Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)
Fix loading checkpoint in DDP training.
parent 78bb65ed78
commit d3101fb005
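Background for the fix (not stated in the commit message, but the likely issue): once a model is wrapped in DistributedDataParallel, every key in its state_dict() gains a "module." prefix, so a checkpoint saved from the plain module no longer matches the wrapped model. Restoring weights before wrapping, as the hunks below do, avoids that mismatch. A minimal, self-contained sketch of the prefix behaviour (single-process gloo group, hypothetical Linear module, not taken from the diff):

import os

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process process group so DDP can be constructed on CPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

m = nn.Linear(2, 2)              # stand-in for the acoustic model
print(list(m.state_dict()))      # ['weight', 'bias']

ddp = DDP(m)
print(list(ddp.state_dict()))    # ['module.weight', 'module.bias']

dist.destroy_process_group()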
@@ -153,8 +153,8 @@ def get_params() -> AttributeDict:
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
-    optimizer: torch.optim.Optimizer,
-    scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
 ) -> None:
     """Load checkpoint from file.
 
@@ -198,6 +198,8 @@ def load_checkpoint_if_available(
     for k in keys:
         params[k] = saved_params[k]
 
+    return saved_params
+
 
 def save_checkpoint(
     params: AttributeDict,
@@ -485,6 +487,9 @@ def run(rank, world_size, args):
         num_classes=max_phone_id + 1,  # +1 for the blank symbol
         subsampling_factor=params.subsampling_factor,
     )
+
+    checkpoints = load_checkpoint_if_available(params=params, model=model)
+
     model.to(device)
     if world_size > 1:
         model = DDP(model, device_ids=[rank])
@@ -496,9 +501,8 @@ def run(rank, world_size, args):
     )
     scheduler = StepLR(optimizer, step_size=8, gamma=0.1)
 
-    load_checkpoint_if_available(
-        params=params, model=model, optimizer=optimizer
-    )
+    optimizer.load_state_dict(checkpoints["optimizer"])
+    scheduler.load_state_dict(checkpoints["scheduler"])
 
     librispeech = LibriSpeechAsrDataModule(args)
     train_dl = librispeech.train_dataloaders()
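Taken together, the two run() hunks above establish this resume order: restore model weights into the unwrapped module, wrap it in DDP, create the optimizer and scheduler, then restore their state from the dict that load_checkpoint_if_available() now returns. A hedged sketch of that flow, assuming load_checkpoint_if_available from the training script above is in scope; the SGD hyper-parameters and the `if checkpoints` guard are illustrative additions, not part of the diff:

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import StepLR


def setup_for_resume(params, model: nn.Module, device: torch.device,
                     rank: int, world_size: int):
    # 1) Load model weights (and bookkeeping such as the starting epoch)
    #    into the plain nn.Module, before any DDP wrapping.
    checkpoints = load_checkpoint_if_available(params=params, model=model)

    model.to(device)
    if world_size > 1:
        model = DDP(model, device_ids=[rank])

    # 2) Build the optimizer/scheduler only now, on the (possibly wrapped) model.
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)  # illustrative choice
    scheduler = StepLR(optimizer, step_size=8, gamma=0.1)

    # 3) Restore their state from the returned checkpoint dict.
    if checkpoints is not None:  # None when training starts from scratch
        optimizer.load_state_dict(checkpoints["optimizer"])
        scheduler.load_state_dict(checkpoints["scheduler"])

    return model, optimizer, scheduler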
@@ -94,7 +94,7 @@ def load_checkpoint(
         s = checkpoint[name]
         if obj and s:
             obj.load_state_dict(s)
-        checkpoint.pop(name)
+            checkpoint.pop(name)
 
     load("optimizer", optimizer)
     load("scheduler", scheduler)
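The last hunk moves checkpoint.pop(name) inside the `if obj and s:` branch, so an entry is removed from the dict only after it has actually been consumed; a model-only call (optimizer=None) now leaves the optimizer and scheduler state in place for run() to restore later. After the change the nested helper inside load_checkpoint() reads roughly as below; the enclosing def line is inferred from the call sites and is not shown in the hunk:

    def load(name, obj):
        s = checkpoint[name]
        if obj and s:
            obj.load_state_dict(s)
            checkpoint.pop(name)  # drop the entry only once it has been applied

    load("optimizer", optimizer)
    load("scheduler", scheduler)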