mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
update decode file
This commit is contained in:
parent
08b37e07a4
commit
aea8a03e00
@ -536,22 +536,21 @@ def main():
|
|||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
else:
|
else:
|
||||||
assert params.iter == 0
|
assert params.iter == 0
|
||||||
if True:
|
start = params.epoch - params.avg
|
||||||
start = params.epoch - params.avg
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
logging.info(
|
||||||
logging.info(
|
f"averaging modes over range with {filename_start} (excluded) "
|
||||||
f"averaging modes over range with {filename_start} (excluded) "
|
f"and {filename_end}"
|
||||||
f"and {filename_end}"
|
)
|
||||||
)
|
model.to(device)
|
||||||
model.to(device)
|
model.load_state_dict(
|
||||||
model.load_state_dict(
|
average_checkpoints_with_averaged_model(
|
||||||
average_checkpoints_with_averaged_model(
|
filename_start=filename_start,
|
||||||
filename_start=filename_start,
|
filename_end=filename_end,
|
||||||
filename_end=filename_end,
|
device=device,
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
@ -416,8 +416,9 @@ def average_checkpoints_with_averaged_model(
|
|||||||
weight_start = -batch_idx_train_start / interval
|
weight_start = -batch_idx_train_start / interval
|
||||||
weight_end = batch_idx_train_end / interval
|
weight_end = batch_idx_train_end / interval
|
||||||
|
|
||||||
avg = state_dict_end["model_avg"]
|
model_end = state_dict_end["model_avg"]
|
||||||
model_start = state_dict_start["model_avg"]
|
model_start = state_dict_start["model_avg"]
|
||||||
|
avg = model_end
|
||||||
|
|
||||||
# Identify shared parameters. Two parameters are said to be shared
|
# Identify shared parameters. Two parameters are said to be shared
|
||||||
# if they have the same data_ptr
|
# if they have the same data_ptr
|
||||||
@ -434,14 +435,3 @@ def average_checkpoints_with_averaged_model(
|
|||||||
avg[k] += model_start[k] * weight_start
|
avg[k] += model_start[k] * weight_start
|
||||||
|
|
||||||
return avg
|
return avg
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint_with_averaged_model(
|
|
||||||
filename: str,
|
|
||||||
model: nn.Module,
|
|
||||||
strict: bool = True,
|
|
||||||
) -> None:
|
|
||||||
"""Load checkpoint with aaveraged model."""
|
|
||||||
logging.info(f"Loading checkpoint from {filename}, using averaged model")
|
|
||||||
checkpoint = torch.load(filename, map_location="cpu")
|
|
||||||
model.load_state_dict(checkpoint["model_avg"], strict=strict)
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user