update decode file

This commit is contained in:
yaozengwei 2022-05-02 12:17:43 +08:00
parent 08b37e07a4
commit aea8a03e00
2 changed files with 16 additions and 27 deletions

View File

@ -536,22 +536,21 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))
else: else:
assert params.iter == 0 assert params.iter == 0
if True: start = params.epoch - params.avg
start = params.epoch - params.avg filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info(
logging.info( f"averaging modes over range with {filename_start} (excluded) "
f"averaging modes over range with {filename_start} (excluded) " f"and {filename_end}"
f"and {filename_end}" )
) model.to(device)
model.to(device) model.load_state_dict(
model.load_state_dict( average_checkpoints_with_averaged_model(
average_checkpoints_with_averaged_model( filename_start=filename_start,
filename_start=filename_start, filename_end=filename_end,
filename_end=filename_end, device=device,
device=device,
)
) )
)
model.to(device) model.to(device)
model.eval() model.eval()

View File

@ -416,8 +416,9 @@ def average_checkpoints_with_averaged_model(
weight_start = -batch_idx_train_start / interval weight_start = -batch_idx_train_start / interval
weight_end = batch_idx_train_end / interval weight_end = batch_idx_train_end / interval
avg = state_dict_end["model_avg"] model_end = state_dict_end["model_avg"]
model_start = state_dict_start["model_avg"] model_start = state_dict_start["model_avg"]
avg = model_end
# Identify shared parameters. Two parameters are said to be shared # Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr # if they have the same data_ptr
@ -434,14 +435,3 @@ def average_checkpoints_with_averaged_model(
avg[k] += model_start[k] * weight_start avg[k] += model_start[k] * weight_start
return avg return avg
def load_checkpoint_with_averaged_model(
filename: str,
model: nn.Module,
strict: bool = True,
) -> None:
"""Load checkpoint with aaveraged model."""
logging.info(f"Loading checkpoint from {filename}, using averaged model")
checkpoint = torch.load(filename, map_location="cpu")
model.load_state_dict(checkpoint["model_avg"], strict=strict)