from local

This commit is contained in:
dohe0342 2023-02-14 15:13:19 +09:00
parent 51b1b74671
commit 771848696f
2 changed files with 12 additions and 11 deletions

View File

@ -778,6 +778,17 @@ def main() -> None:
group_num=params.group_num, group_num=params.group_num,
) )
# we need cut ids to display recognition results.
args.return_cuts = True
tedlium = TedLiumAsrDataModule(args)
valid_cuts = tedlium.dev_cuts()
valid_dl = tedlium.valid_dataloaders(valid_cuts)
test_sets = ["dev"]
test_dls = [valid_dl]
for epoch in range(params.start, params.end+1): for epoch in range(params.start, params.end+1):
load_checkpoint(f"{params.exp_dir}/epoch-{epoch}.pt", model) load_checkpoint(f"{params.exp_dir}/epoch-{epoch}.pt", model)
@ -786,17 +797,7 @@ def main() -> None:
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
tedlium = TedLiumAsrDataModule(args)
valid_cuts = tedlium.dev_cuts()
valid_dl = tedlium.valid_dataloaders(valid_cuts)
test_sets = ["dev"]
test_dls = [dev_dl]
for test_set, test_dl in zip(test_sets, test_dls): for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset( results_dict = decode_dataset(
dl=test_dl, dl=test_dl,