mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
minor fix
This commit is contained in:
parent
fba9ae0502
commit
08b37e07a4
@ -81,7 +81,6 @@ from icefall.checkpoint import (
|
|||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
load_checkpoint_with_averaged_model,
|
|
||||||
)
|
)
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -481,6 +480,9 @@ def main():
|
|||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
|
|
||||||
|
if params.use_averaged_model:
|
||||||
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
|
|
||||||
@ -534,15 +536,14 @@ def main():
|
|||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
else:
|
else:
|
||||||
assert params.iter == 0
|
assert params.iter == 0
|
||||||
if params.avg == 1:
|
if True:
|
||||||
filename = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
|
||||||
load_checkpoint_with_averaged_model(filename, model)
|
|
||||||
else:
|
|
||||||
assert params.avg > 1
|
|
||||||
start = params.epoch - params.avg
|
start = params.epoch - params.avg
|
||||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
logging.info(f"averaging {filename_start} and {filename_end}")
|
logging.info(
|
||||||
|
f"averaging modes over range with {filename_start} (excluded) "
|
||||||
|
f"and {filename_end}"
|
||||||
|
)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(
|
model.load_state_dict(
|
||||||
average_checkpoints_with_averaged_model(
|
average_checkpoints_with_averaged_model(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user