minor fix

This commit is contained in:
yaozengwei 2022-05-02 00:50:32 +08:00
parent fba9ae0502
commit 08b37e07a4

View File

@ -81,7 +81,6 @@ from icefall.checkpoint import (
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
load_checkpoint_with_averaged_model,
)
from icefall.utils import (
AttributeDict,
@ -481,6 +480,9 @@ def main():
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@ -534,15 +536,14 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
assert params.iter == 0
if params.avg == 1:
filename = f"{params.exp_dir}/epoch-{params.epoch}.pt"
load_checkpoint_with_averaged_model(filename, model)
else:
assert params.avg > 1
if True:
start = params.epoch - params.avg
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(f"averaging {filename_start} and {filename_end}")
logging.info(
f"averaging modes over range with {filename_start} (excluded) "
f"and {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(