mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
51b1b74671
commit
771848696f
Binary file not shown.
@ -778,14 +778,6 @@ def main() -> None:
|
|||||||
group_num=params.group_num,
|
group_num=params.group_num,
|
||||||
)
|
)
|
||||||
|
|
||||||
for epoch in range(params.start, params.end+1):
|
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{epoch}.pt", model)
|
|
||||||
|
|
||||||
model.to(device)
|
|
||||||
model.eval()
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
|
||||||
|
|
||||||
# we need cut ids to display recognition results.
|
# we need cut ids to display recognition results.
|
||||||
args.return_cuts = True
|
args.return_cuts = True
|
||||||
tedlium = TedLiumAsrDataModule(args)
|
tedlium = TedLiumAsrDataModule(args)
|
||||||
@ -795,7 +787,16 @@ def main() -> None:
|
|||||||
valid_dl = tedlium.valid_dataloaders(valid_cuts)
|
valid_dl = tedlium.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
test_sets = ["dev"]
|
test_sets = ["dev"]
|
||||||
test_dls = [dev_dl]
|
test_dls = [valid_dl]
|
||||||
|
|
||||||
|
for epoch in range(params.start, params.end+1):
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{epoch}.pt", model)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
|
||||||
for test_set, test_dl in zip(test_sets, test_dls):
|
for test_set, test_dl in zip(test_sets, test_dls):
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user