update decode file

This commit is contained in:
yaozengwei 2022-05-02 12:17:43 +08:00
parent 08b37e07a4
commit aea8a03e00
2 changed files with 16 additions and 27 deletions

View File

@ -536,22 +536,21 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
assert params.iter == 0
if True:
start = params.epoch - params.avg
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"averaging modes over range with {filename_start} (excluded) "
f"and {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
start = params.epoch - params.avg
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"averaging modes over range with {filename_start} (excluded) "
f"and {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()

View File

@ -416,8 +416,9 @@ def average_checkpoints_with_averaged_model(
weight_start = -batch_idx_train_start / interval
weight_end = batch_idx_train_end / interval
avg = state_dict_end["model_avg"]
model_end = state_dict_end["model_avg"]
model_start = state_dict_start["model_avg"]
avg = model_end
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
@ -434,14 +435,3 @@ def average_checkpoints_with_averaged_model(
avg[k] += model_start[k] * weight_start
return avg
def load_checkpoint_with_averaged_model(
    filename: str,
    model: nn.Module,
    strict: bool = True,
) -> None:
    """Load the averaged model weights stored in a checkpoint into `model`.

    Args:
      filename:
        Path of the checkpoint file to read with `torch.load`.
      model:
        The module that receives the weights; it is updated in place.
      strict:
        Forwarded unchanged to `model.load_state_dict`.

    Raises:
      KeyError:
        If the checkpoint has no "model_avg" entry.
    """
    logging.info(f"Loading checkpoint from {filename}, using averaged model")
    # map_location="cpu" asks torch.load to place the stored tensors on the
    # CPU regardless of the device they were saved from.
    checkpoint = torch.load(filename, map_location="cpu")
    if "model_avg" not in checkpoint:
        # Same exception type a plain lookup would raise, but the message
        # names the offending file instead of only the missing key.
        raise KeyError(
            f"Checkpoint {filename} does not contain 'model_avg'; "
            "it cannot be loaded as an averaged model"
        )
    model.load_state_dict(checkpoint["model_avg"], strict=strict)