Use GPU for averaging checkpoints if possible. (#84)

Fangjun Kuang 2021-10-26 17:10:04 +08:00 committed by GitHub
parent 712ead8207
commit 8cb7f712e4
3 changed files with 11 additions and 5 deletions


@@ -640,7 +640,8 @@ def main():
             if start >= 0:
                 filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
         logging.info(f"averaging {filenames}")
-        model.load_state_dict(average_checkpoints(filenames))
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")


@@ -457,7 +457,8 @@ def main():
             if start >= 0:
                 filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
         logging.info(f"averaging {filenames}")
-        model.load_state_dict(average_checkpoints(filenames))
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")


@@ -120,22 +120,26 @@ def load_checkpoint(
     return checkpoint
 
 
-def average_checkpoints(filenames: List[Path]) -> dict:
+def average_checkpoints(
+    filenames: List[Path], device: torch.device = torch.device("cpu")
+) -> dict:
     """Average a list of checkpoints.
 
     Args:
       filenames:
         Filenames of the checkpoints to be averaged. We assume all
        checkpoints are saved by :func:`save_checkpoint`.
+      device:
+        Move checkpoints to this device before averaging.
     Returns:
       Return a dict (i.e., state_dict) which is the average of all
       model state dicts contained in the checkpoints.
     """
     n = len(filenames)
 
-    avg = torch.load(filenames[0], map_location="cpu")["model"]
+    avg = torch.load(filenames[0], map_location=device)["model"]
     for i in range(1, n):
-        state_dict = torch.load(filenames[i], map_location="cpu")["model"]
+        state_dict = torch.load(filenames[i], map_location=device)["model"]
         for k in avg:
             avg[k] += state_dict[k]
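
For context, a minimal sketch of how the extended average_checkpoints is meant to be called from a decoding script after this change. Only the device=device argument and the model.to(device) call come from the hunks above; the experiment directory, epoch values, and the CUDA-availability check are illustrative assumptions, not part of this commit.

import torch

from icefall.checkpoint import average_checkpoints  # module shown in the last hunk

# Illustrative values; in the decode scripts they come from the parsed params.
exp_dir = "conformer_ctc/exp"  # hypothetical experiment directory
epoch, avg = 34, 20
start = epoch - avg + 1

# Prefer the GPU when one is available, so the checkpoints are loaded and
# summed on the GPU instead of the CPU (this check is assumed typical usage).
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

# Checkpoints saved by save_checkpoint(), as in the decode.py hunks above.
filenames = [f"{exp_dir}/epoch-{i}.pt" for i in range(start, epoch + 1) if start >= 0]

model.to(device)  # `model` is the network already constructed in decode.py
model.load_state_dict(average_checkpoints(filenames, device=device))

Since the averaged tensors already live on device, moving the model there before load_state_dict keeps the copy on-device instead of routing the averaged weights back through the CPU.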