Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)

Commit 8cb7f712e4 — "Use GPU for averaging checkpoints if possible. (#84)"
Parent commit: 712ead8207

This commit is contained in the diff hunks below.
@@ -640,7 +640,8 @@ def main():
         if start >= 0:
             filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
         logging.info(f"averaging {filenames}")
-        model.load_state_dict(average_checkpoints(filenames))
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))

     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
|
@@ -457,7 +457,8 @@ def main():
         if start >= 0:
             filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
         logging.info(f"averaging {filenames}")
-        model.load_state_dict(average_checkpoints(filenames))
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))

     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
|
@@ -120,22 +120,26 @@ def load_checkpoint(
     return checkpoint


-def average_checkpoints(filenames: List[Path]) -> dict:
+def average_checkpoints(
+    filenames: List[Path], device: torch.device = torch.device("cpu")
+) -> dict:
     """Average a list of checkpoints.

     Args:
       filenames:
         Filenames of the checkpoints to be averaged. We assume all
         checkpoints are saved by :func:`save_checkpoint`.
+      device:
+        Move checkpoints to this device before averaging.
     Returns:
       Return a dict (i.e., state_dict) which is the average of all
       model state dicts contained in the checkpoints.
     """
     n = len(filenames)

-    avg = torch.load(filenames[0], map_location="cpu")["model"]
+    avg = torch.load(filenames[0], map_location=device)["model"]
     for i in range(1, n):
-        state_dict = torch.load(filenames[i], map_location="cpu")["model"]
+        state_dict = torch.load(filenames[i], map_location=device)["model"]
         for k in avg:
             avg[k] += state_dict[k]
|
Loading…
x
Reference in New Issue
Block a user