minor fix

This commit is contained in:
yaozengwei 2022-05-02 00:50:32 +08:00
parent fba9ae0502
commit 08b37e07a4

View File

@ -81,7 +81,6 @@ from icefall.checkpoint import (
average_checkpoints_with_averaged_model, average_checkpoints_with_averaged_model,
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
load_checkpoint_with_averaged_model,
) )
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -481,6 +480,9 @@ def main():
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started") logging.info("Decoding started")
@ -534,15 +536,14 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))
else: else:
assert params.iter == 0 assert params.iter == 0
if params.avg == 1: if True:
filename = f"{params.exp_dir}/epoch-{params.epoch}.pt"
load_checkpoint_with_averaged_model(filename, model)
else:
assert params.avg > 1
start = params.epoch - params.avg start = params.epoch - params.avg
filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(f"averaging {filename_start} and {filename_end}") logging.info(
f"averaging modes over range with {filename_start} (excluded) "
f"and {filename_end}"
)
model.to(device) model.to(device)
model.load_state_dict( model.load_state_dict(
average_checkpoints_with_averaged_model( average_checkpoints_with_averaged_model(