mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
minor fix
This commit is contained in:
parent
fba9ae0502
commit
08b37e07a4
@ -81,7 +81,6 @@ from icefall.checkpoint import (
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
load_checkpoint_with_averaged_model,
|
||||
)
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -481,6 +480,9 @@ def main():
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
@ -534,15 +536,14 @@ def main():
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
assert params.iter == 0
|
||||
if params.avg == 1:
|
||||
filename = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
load_checkpoint_with_averaged_model(filename, model)
|
||||
else:
|
||||
assert params.avg > 1
|
||||
if True:
|
||||
start = params.epoch - params.avg
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(f"averaging {filename_start} and {filename_end}")
|
||||
logging.info(
|
||||
f"averaging modes over range with {filename_start} (excluded) "
|
||||
f"and {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
|
Loading…
x
Reference in New Issue
Block a user