mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
update decode file
This commit is contained in:
parent
08b37e07a4
commit
aea8a03e00
@ -536,22 +536,21 @@ def main():
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
assert params.iter == 0
|
||||
if True:
|
||||
start = params.epoch - params.avg
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"averaging modes over range with {filename_start} (excluded) "
|
||||
f"and {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
start = params.epoch - params.avg
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"averaging modes over range with {filename_start} (excluded) "
|
||||
f"and {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
@ -416,8 +416,9 @@ def average_checkpoints_with_averaged_model(
|
||||
weight_start = -batch_idx_train_start / interval
|
||||
weight_end = batch_idx_train_end / interval
|
||||
|
||||
avg = state_dict_end["model_avg"]
|
||||
model_end = state_dict_end["model_avg"]
|
||||
model_start = state_dict_start["model_avg"]
|
||||
avg = model_end
|
||||
|
||||
# Identify shared parameters. Two parameters are said to be shared
|
||||
# if they have the same data_ptr
|
||||
@ -434,14 +435,3 @@ def average_checkpoints_with_averaged_model(
|
||||
avg[k] += model_start[k] * weight_start
|
||||
|
||||
return avg
|
||||
|
||||
|
||||
def load_checkpoint_with_averaged_model(
|
||||
filename: str,
|
||||
model: nn.Module,
|
||||
strict: bool = True,
|
||||
) -> None:
|
||||
"""Load checkpoint with aaveraged model."""
|
||||
logging.info(f"Loading checkpoint from {filename}, using averaged model")
|
||||
checkpoint = torch.load(filename, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model_avg"], strict=strict)
|
||||
|
Loading…
x
Reference in New Issue
Block a user