Use GPU for averaging checkpoints if possible.

Author: Fangjun Kuang
Date: 2021-10-18 11:32:10 +08:00
Parent: bd7c2f7645
Commit: 47519b36a4

3 changed files with 11 additions and 5 deletions


@@ -638,7 +638,8 @@ def main():
             if start >= 0:
                 filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
         logging.info(f"averaging {filenames}")
-        model.load_state_dict(average_checkpoints(filenames))
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")

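Both decode-script hunks (the one above and the one below) rely on a `device` variable defined earlier in main(). A minimal sketch of how such a device is commonly selected; the CUDA check is an assumption and is not part of this commit:

    import torch

    # Not from this diff: pick the first GPU when one is available, otherwise fall back to CPU.
    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    print(f"averaging checkpoints on {device}")
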

@@ -455,7 +455,8 @@ def main():
             if start >= 0:
                 filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
         logging.info(f"averaging {filenames}")
-        model.load_state_dict(average_checkpoints(filenames))
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")


@@ -120,22 +120,26 @@ def load_checkpoint(
     return checkpoint
 
 
-def average_checkpoints(filenames: List[Path]) -> dict:
+def average_checkpoints(
+    filenames: List[Path], device: torch.device = torch.device("cpu")
+) -> dict:
     """Average a list of checkpoints.
 
     Args:
       filenames:
        Filenames of the checkpoints to be averaged. We assume all
        checkpoints are saved by :func:`save_checkpoint`.
+      device:
+        Move checkpoints to this device before averaging.
     Returns:
       Return a dict (i.e., state_dict) which is the average of all
       model state dicts contained in the checkpoints.
     """
     n = len(filenames)
 
-    avg = torch.load(filenames[0], map_location="cpu")["model"]
+    avg = torch.load(filenames[0], map_location=device)["model"]
     for i in range(1, n):
-        state_dict = torch.load(filenames[i], map_location="cpu")["model"]
+        state_dict = torch.load(filenames[i], map_location=device)["model"]
         for k in avg:
             avg[k] += state_dict[k]
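
For reference, a self-contained sketch of what the updated helper effectively does end to end. The tiny nn.Linear model, the temporary checkpoint files, the {"model": state_dict} layout, and the final division by the number of checkpoints (implied by the docstring but not shown in this hunk) are all assumptions for illustration, not the project's code verbatim:

    import tempfile
    from pathlib import Path
    from typing import List

    import torch
    import torch.nn as nn


    def average_state_dicts(
        filenames: List[Path], device: torch.device = torch.device("cpu")
    ) -> dict:
        """Illustrative re-implementation: sum the "model" state dicts on `device`,
        then divide by the number of checkpoints."""
        n = len(filenames)
        # Loading with map_location=device puts every tensor on the GPU when
        # `device` is a CUDA device, so the summation below runs there too.
        avg = torch.load(filenames[0], map_location=device)["model"]
        for i in range(1, n):
            state_dict = torch.load(filenames[i], map_location=device)["model"]
            for k in avg:
                avg[k] += state_dict[k]
        for k in avg:
            avg[k] /= n  # assumed final step; the hunk above ends before it
        return avg


    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

    with tempfile.TemporaryDirectory() as tmp:
        filenames = []
        for epoch in (1, 2):
            model = nn.Linear(4, 2)
            path = Path(tmp) / f"epoch-{epoch}.pt"
            # Hypothetical files that mimic the {"model": ...} layout the helper expects.
            torch.save({"model": model.state_dict()}, path)
            filenames.append(path)

        averaged = average_state_dicts(filenames, device=device)
        print({k: v.device for k, v in averaged.items()})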