Fix loading checkpoint in DDP training.

Fangjun Kuang 2021-07-26 08:08:14 +08:00
parent 78bb65ed78
commit d3101fb005
2 changed files with 10 additions and 6 deletions


@@ -153,8 +153,8 @@ def get_params() -> AttributeDict:
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
-    optimizer: torch.optim.Optimizer,
-    scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
 ) -> None:
     """Load checkpoint from file.
 
@@ -198,6 +198,8 @@ def load_checkpoint_if_available(
     for k in keys:
         params[k] = saved_params[k]
 
+    return saved_params
+
 
 def save_checkpoint(
     params: AttributeDict,
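Taken together, the two hunks above change the contract of load_checkpoint_if_available: optimizer and scheduler now default to None, and the loaded state dict is returned to the caller. Below is a compressed sketch of the resulting behavior; the filename scheme, the start_epoch early return, and the copied bookkeeping keys are assumptions based on the surrounding code, not part of this diff.

# Compressed sketch of load_checkpoint_if_available after this commit.
# Filename scheme and bookkeeping keys are assumed, not taken from the diff.
from pathlib import Path
from typing import Any, Dict, Optional

import torch
import torch.nn as nn


def load_checkpoint_if_available(
    params,  # AttributeDict-like: needs exp_dir, start_epoch, item assignment
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> Optional[Dict[str, Any]]:
    if params.start_epoch <= 0:
        return None  # fresh run, nothing to resume

    filename = Path(params.exp_dir) / f"epoch-{params.start_epoch - 1}.pt"
    saved_params = torch.load(filename, map_location="cpu")

    model.load_state_dict(saved_params["model"])
    # Optimizer/scheduler state is applied only when those objects exist;
    # on the model-only first call their entries stay in saved_params.
    if optimizer is not None:
        optimizer.load_state_dict(saved_params["optimizer"])
    if scheduler is not None:
        scheduler.load_state_dict(saved_params["scheduler"])

    # Copy training bookkeeping into params, as in the loop shown above.
    for k in ("epoch",):  # assumed key set
        params[k] = saved_params[k]

    return saved_params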
@@ -485,6 +487,9 @@ def run(rank, world_size, args):
         num_classes=max_phone_id + 1,  # +1 for the blank symbol
         subsampling_factor=params.subsampling_factor,
     )
+
+    checkpoints = load_checkpoint_if_available(params=params, model=model)
+
     model.to(device)
     if world_size > 1:
         model = DDP(model, device_ids=[rank])
@@ -496,9 +501,8 @@ def run(rank, world_size, args):
     )
     scheduler = StepLR(optimizer, step_size=8, gamma=0.1)
 
-    load_checkpoint_if_available(
-        params=params, model=model, optimizer=optimizer
-    )
+    optimizer.load_state_dict(checkpoints["optimizer"])
+    scheduler.load_state_dict(checkpoints["scheduler"])
 
     librispeech = LibriSpeechAsrDataModule(args)
     train_dl = librispeech.train_dataloaders()
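The net effect in run() is a two-phase resume: model weights are restored while the model is still a plain nn.Module, before DDP wraps it and prefixes every state-dict key with "module.", and the optimizer and scheduler states are applied only after those objects are constructed. A minimal self-contained round trip of that order follows; the model shape, optimizer choice, and hyperparameters are stand-ins, not the real training setup.

# Minimal self-contained round trip of the resume order this diff sets up.
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR


def make_model() -> nn.Module:
    return nn.Linear(80, 500)  # stand-in for the acoustic model


# Pretend an earlier run saved this checkpoint.
prev_model = make_model()
prev_opt = torch.optim.SGD(prev_model.parameters(), lr=1e-3)
prev_sched = StepLR(prev_opt, step_size=8, gamma=0.1)
saved = {
    "model": prev_model.state_dict(),
    "optimizer": prev_opt.state_dict(),
    "scheduler": prev_sched.state_dict(),
}

# Phase 1: restore weights while the model is a plain nn.Module, i.e.
# before any DDP(model, device_ids=[rank]) wrapping adds a "module."
# prefix to every state-dict key.
model = make_model()
model.load_state_dict(saved["model"])

# Phase 2: the optimizer and scheduler exist only after the model is
# (optionally) wrapped, so their states are applied afterwards.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=8, gamma=0.1)
optimizer.load_state_dict(saved["optimizer"])
scheduler.load_state_dict(saved["scheduler"])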


@@ -94,7 +94,7 @@ def load_checkpoint(
         s = checkpoint[name]
         if obj and s:
             obj.load_state_dict(s)
-        checkpoint.pop(name)
+            checkpoint.pop(name)
 
     load("optimizer", optimizer)
     load("scheduler", scheduler)