diff --git a/egs/tedlium2/ASR/conformer_ctc3/decode_multi.py b/egs/tedlium2/ASR/conformer_ctc3/decode_multi.py
index 030c602cf..2c4d8725c 100755
--- a/egs/tedlium2/ASR/conformer_ctc3/decode_multi.py
+++ b/egs/tedlium2/ASR/conformer_ctc3/decode_multi.py
@@ -777,121 +777,43 @@ def main() -> None:
         num_decoder_layers=params.num_decoder_layers,
         group_num=params.group_num,
     )
+
+    model.to(device)
+    model.eval()
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
 
-    if not params.use_averaged_model:
-        if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
-            if len(filenames) == 0:
-                raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            elif len(filenames) < params.avg:
-                raise ValueError(
-                    f"Not enough checkpoints ({len(filenames)}) found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            logging.info(f"averaging {filenames}")
-            model.to(device)
-            model.load_state_dict(average_checkpoints(filenames, device=device))
-        elif params.avg == 1:
-            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
-        else:
-            start = params.epoch - params.avg + 1
-            filenames = []
-            for i in range(start, params.epoch + 1):
-                if i >= 1:
-                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
-            logging.info(f"averaging {filenames}")
-            model.to(device)
-            model.load_state_dict(average_checkpoints(filenames, device=device))
-    else:
-        if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
-            if len(filenames) == 0:
-                raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            elif len(filenames) < params.avg + 1:
-                raise ValueError(
-                    f"Not enough checkpoints ({len(filenames)}) found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            filename_start = filenames[-1]
-            filename_end = filenames[0]
-            logging.info(
-                "Calculating the averaged model over iteration checkpoints"
-                f" from {filename_start} (excluded) to {filename_end}"
-            )
-            model.to(device)
-            model.load_state_dict(
-                average_checkpoints_with_averaged_model(
-                    filename_start=filename_start,
-                    filename_end=filename_end,
-                    device=device,
-                )
-            )
-        else:
-            assert params.avg > 0, params.avg
-            start = params.epoch - params.avg
-            assert start >= 1, start
-            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
-            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
-            logging.info(
-                f"Calculating the averaged model over epoch range from "
-                f"{start} (excluded) to {params.epoch}"
-            )
-            model.to(device)
-            model.load_state_dict(
-                average_checkpoints_with_averaged_model(
-                    filename_start=filename_start,
-                    filename_end=filename_end,
-                    device=device,
-                )
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    tedlium = TedLiumAsrDataModule(args)
+
+    valid_cuts = tedlium.dev_cuts()
+    valid_dl = tedlium.valid_dataloaders(valid_cuts)
+
+    test_sets = ["dev"]
+    test_dls = [valid_dl]
+
+    # Decode with the checkpoint of every epoch in [params.start, params.end].
+    for epoch in range(params.start, params.end + 1):
+        load_checkpoint(f"{params.exp_dir}/epoch-{epoch}.pt", model)
+        # Use a per-epoch suffix so that the result files of different
+        # epochs do not overwrite each other.
+        params.suffix = f"epoch-{epoch}"
+
+        for test_set, test_dl in zip(test_sets, test_dls):
+            results_dict = decode_dataset(
+                dl=test_dl,
+                params=params,
+                model=model,
+                HLG=HLG,
+                H=H,
+                bpe_model=bpe_model,
+                word_table=lexicon.word_table,
+                G=G,
+                sos_id=sos_id,
+                eos_id=eos_id,
             )
-    model.to(device)
-    model.eval()
-    num_param = sum([p.numel() for p in model.parameters()])
-    logging.info(f"Number of model parameters: {num_param}")
+
+            save_results(
+                params=params,
+                test_set_name=test_set,
+                results_dict=results_dict,
+            )
 
-    # we need cut ids to display recognition results.
-    args.return_cuts = True
-    tedlium = TedLiumAsrDataModule(args)
-
-    valid_cuts = tedlium.dev_cuts()
-    test_cuts = tedlium.test_cuts()
-
-    valid_dl = tedlium.valid_dataloaders(valid_cuts)
-    test_dl = tedlium.test_dataloaders(test_cuts)
-
-    #test_sets = ["dev", "test"]
-    #test_dls = [valid_dl, test_dl]
-    test_sets = ["dev"]
-    test_dls = [dev_dl]
-
-    for test_set, test_dl in zip(test_sets, test_dls):
-        results_dict = decode_dataset(
-            dl=test_dl,
-            params=params,
-            model=model,
-            HLG=HLG,
-            H=H,
-            bpe_model=bpe_model,
-            word_table=lexicon.word_table,
-            G=G,
-            sos_id=sos_id,
-            eos_id=eos_id,
-        )
-
-        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
-
-    logging.info("Done!")
+    logging.info("Done!")
 
 
 torch.set_num_threads(1)
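
Note: the new loop reads params.start and params.end, but this diff does not
add matching command-line options, so the script will fail unless they are
already defined elsewhere in decode_multi.py. A minimal sketch of the missing
get_parser() additions, assuming --start/--end are the option names meant to
back params.start/params.end (names assumed, not taken from the diff):

    # Hypothetical sketch: --start/--end are assumed option names backing
    # params.start / params.end; skip this if the file already defines them.
    parser.add_argument(
        "--start",
        type=int,
        default=1,
        help="First epoch checkpoint to decode, i.e. epoch-<start>.pt.",
    )
    parser.add_argument(
        "--end",
        type=int,
        default=1,
        help="Last epoch checkpoint to decode (inclusive), i.e. epoch-<end>.pt.",
    )

With those options defined, decoding epochs 20 through 30 would look like:
./conformer_ctc3/decode_multi.py --exp-dir conformer_ctc3/exp --start 20 --end 30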