Repository: https://github.com/k2-fsa/icefall.git
Fix loading checkpoint in DDP training.
commit d3101fb005
parent 78bb65ed78
@@ -153,8 +153,8 @@ def get_params() -> AttributeDict:
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
-    optimizer: torch.optim.Optimizer,
-    scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
 ) -> None:
     """Load checkpoint from file.

@@ -198,6 +198,8 @@ def load_checkpoint_if_available(
     for k in keys:
         params[k] = saved_params[k]

+    return saved_params
+

 def save_checkpoint(
     params: AttributeDict,
@@ -485,6 +487,9 @@ def run(rank, world_size, args):
         num_classes=max_phone_id + 1,  # +1 for the blank symbol
         subsampling_factor=params.subsampling_factor,
     )
+
+    checkpoints = load_checkpoint_if_available(params=params, model=model)
+
     model.to(device)
     if world_size > 1:
         model = DDP(model, device_ids=[rank])
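
The effect of this hunk is that the checkpoint is now applied to the bare module before it is moved to the device and wrapped in DDP. Below is a minimal single-process sketch of that order, using a placeholder nn.Linear model and a hypothetical checkpoint path (the real script delegates the loading to load_checkpoint_if_available):

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

model = nn.Linear(80, 10)           # placeholder for the acoustic model
checkpoint_path = "exp/epoch-3.pt"  # hypothetical path

# Restore the bare module first, mirroring the new call order above.
checkpoints = None
try:
    checkpoints = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoints["model"])
except FileNotFoundError:
    pass  # no checkpoint yet; start from scratch

# Only after the (possible) restore is the model moved and wrapped.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

world_size, rank = 1, 0  # single process in this sketch
if world_size > 1:
    # requires torch.distributed.init_process_group() to have been called
    model = DDP(model, device_ids=[rank])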
@@ -496,9 +501,8 @@ def run(rank, world_size, args):
     )
     scheduler = StepLR(optimizer, step_size=8, gamma=0.1)

-    load_checkpoint_if_available(
-        params=params, model=model, optimizer=optimizer
-    )
+    optimizer.load_state_dict(checkpoints["optimizer"])
+    scheduler.load_state_dict(checkpoints["scheduler"])

     librispeech = LibriSpeechAsrDataModule(args)
     train_dl = librispeech.train_dataloaders()
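
After this hunk, the optimizer and scheduler are created first and their saved states are restored from the dict returned by load_checkpoint_if_available(). A minimal sketch of that restore step, with a placeholder model and an in-memory dict standing in for the returned checkpoint:

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(80, 10)                                  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)  # stands in for the script's optimizer
scheduler = StepLR(optimizer, step_size=8, gamma=0.1)

# Stand-in for the dict returned by load_checkpoint_if_available().
checkpoints = {
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
}

# Restore after the objects exist, as in the new code above.
optimizer.load_state_dict(checkpoints["optimizer"])
scheduler.load_state_dict(checkpoints["scheduler"])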
|
@@ -94,7 +94,7 @@ def load_checkpoint(
         s = checkpoint[name]
         if obj and s:
             obj.load_state_dict(s)
-        checkpoint.pop(name)
+            checkpoint.pop(name)

     load("optimizer", optimizer)
     load("scheduler", scheduler)
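
The helper touched by this last hunk only applies and consumes a saved state when both the target object and the state are present, which is what allows load_checkpoint() to be called with optimizer=None and scheduler=None while keeping their states in the returned dict. A standalone sketch of that behaviour (in checkpoint.py the helper is a closure over the loaded checkpoint dict; it is rewritten here with an explicit argument for illustration):

import torch
import torch.nn as nn

def load(name, obj, checkpoint):
    # Apply and remove the saved state only when a target object was given;
    # otherwise the state stays in the dict for the caller to restore later.
    s = checkpoint[name]
    if obj and s:
        obj.load_state_dict(s)
        checkpoint.pop(name)

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
checkpoint = {"optimizer": optimizer.state_dict()}

load("optimizer", None, checkpoint)       # obj is None: state is kept in the dict
assert "optimizer" in checkpoint
load("optimizer", optimizer, checkpoint)  # state is applied and popped
assert "optimizer" not in checkpoint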
|