mirror of https://github.com/k2-fsa/icefall.git
Use GPU for averaging checkpoints if possible. (#84)
parent 712ead8207
commit 8cb7f712e4
@@ -640,7 +640,8 @@ def main():
             if start >= 0:
                 filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
         logging.info(f"averaging {filenames}")
-        model.load_state_dict(average_checkpoints(filenames))
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))

     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
@@ -457,7 +457,8 @@ def main():
             if start >= 0:
                 filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
         logging.info(f"averaging {filenames}")
-        model.load_state_dict(average_checkpoints(filenames))
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))

     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
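In both decode hunks above, the model is now moved to `device` before the averaged state dict (which also lives on `device`) is loaded into it, so the averaging runs on the GPU whenever one is available. A minimal sketch of the device setup these decode scripts typically use is shown below; the exact lines are not part of this diff and are given only for illustration.

    import torch

    # Pick the GPU when one is available, otherwise fall back to the CPU.
    # This mirrors the usual device setup in the decode scripts.
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)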
@@ -120,22 +120,26 @@ def load_checkpoint(
     return checkpoint


-def average_checkpoints(filenames: List[Path]) -> dict:
+def average_checkpoints(
+    filenames: List[Path], device: torch.device = torch.device("cpu")
+) -> dict:
     """Average a list of checkpoints.

     Args:
       filenames:
         Filenames of the checkpoints to be averaged. We assume all
         checkpoints are saved by :func:`save_checkpoint`.
+      device:
+        Move checkpoints to this device before averaging.
     Returns:
       Return a dict (i.e., state_dict) which is the average of all
       model state dicts contained in the checkpoints.
     """
     n = len(filenames)

-    avg = torch.load(filenames[0], map_location="cpu")["model"]
+    avg = torch.load(filenames[0], map_location=device)["model"]
     for i in range(1, n):
-        state_dict = torch.load(filenames[i], map_location="cpu")["model"]
+        state_dict = torch.load(filenames[i], map_location=device)["model"]
         for k in avg:
             avg[k] += state_dict[k]

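For reference, here is a sketch of how `average_checkpoints` reads after this patch. The tail of the function falls outside the hunk above, so the final division by the number of checkpoints is reconstructed from the docstring and should be treated as an assumption rather than a verbatim copy of the repository code.

    from pathlib import Path
    from typing import List

    import torch


    def average_checkpoints(
        filenames: List[Path], device: torch.device = torch.device("cpu")
    ) -> dict:
        """Average a list of checkpoints.

        Args:
          filenames:
            Filenames of the checkpoints to be averaged. We assume all
            checkpoints are saved by :func:`save_checkpoint`.
          device:
            Move checkpoints to this device before averaging.
        Returns:
          Return a dict (i.e., state_dict) which is the average of all
          model state dicts contained in the checkpoints.
        """
        n = len(filenames)

        # Load every checkpoint directly onto the target device, so the
        # element-wise sums below run on the GPU when one is passed in.
        avg = torch.load(filenames[0], map_location=device)["model"]
        for i in range(1, n):
            state_dict = torch.load(filenames[i], map_location=device)["model"]
            for k in avg:
                avg[k] += state_dict[k]

        # Not shown in the hunk above: turn the sums into averages.
        # (Reconstructed from the docstring; the real tail may differ.)
        for k in avg:
            avg[k] = avg[k] / n

        return avg

Callers that want GPU averaging pass the same device they move the model to, as the decode hunks above show.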