diff --git a/egs/tedlium2/ASR/conformer_ctc3/.decode_multi.py.swp b/egs/tedlium2/ASR/conformer_ctc3/.decode_multi.py.swp index cbaf71bb0..5f03243df 100644 Binary files a/egs/tedlium2/ASR/conformer_ctc3/.decode_multi.py.swp and b/egs/tedlium2/ASR/conformer_ctc3/.decode_multi.py.swp differ diff --git a/egs/tedlium2/ASR/conformer_ctc3/decode_multi.py b/egs/tedlium2/ASR/conformer_ctc3/decode_multi.py index 2c4d8725c..ff13ff66c 100755 --- a/egs/tedlium2/ASR/conformer_ctc3/decode_multi.py +++ b/egs/tedlium2/ASR/conformer_ctc3/decode_multi.py @@ -778,6 +778,17 @@ def main() -> None: group_num=params.group_num, ) + # we need cut ids to display recognition results. + args.return_cuts = True + tedlium = TedLiumAsrDataModule(args) + + valid_cuts = tedlium.dev_cuts() + + valid_dl = tedlium.valid_dataloaders(valid_cuts) + + test_sets = ["dev"] + test_dls = [valid_dl] + for epoch in range(params.start, params.end+1): load_checkpoint(f"{params.exp_dir}/epoch-{epoch}.pt", model) @@ -786,17 +797,7 @@ def main() -> None: num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - # we need cut ids to display recognition results. - args.return_cuts = True - tedlium = TedLiumAsrDataModule(args) - - valid_cuts = tedlium.dev_cuts() - - valid_dl = tedlium.valid_dataloaders(valid_cuts) - - test_sets = ["dev"] - test_dls = [dev_dl] - + for test_set, test_dl in zip(test_sets, test_dls): results_dict = decode_dataset( dl=test_dl,