From 6255ba5cb202bcc8bd1bdb54f3d52cacc5d8c6f7 Mon Sep 17 00:00:00 2001
From: Kinan Martin
Date: Fri, 6 Jun 2025 11:29:29 +0900
Subject: [PATCH] fix decode script data module usage

---
 egs/multi_ja_en/ASR/zipformer/decode.py | 59 +++++++++++++++----------
 1 file changed, 35 insertions(+), 24 deletions(-)

diff --git a/egs/multi_ja_en/ASR/zipformer/decode.py b/egs/multi_ja_en/ASR/zipformer/decode.py
index 37cf39ddd..a3aeb78fa 100755
--- a/egs/multi_ja_en/ASR/zipformer/decode.py
+++ b/egs/multi_ja_en/ASR/zipformer/decode.py
@@ -157,14 +157,14 @@ def get_parser():
     parser.add_argument(
         "--bpe-model",
         type=str,
-        default="data/lang_bbpe_2000/bbpe.model",
+        default="data/lang/bbpe_2000/bbpe.model",
         help="Path to the BPE model",
     )
 
     parser.add_argument(
         "--lang-dir",
         type=Path,
-        default="data/lang_bbpe_2000",
+        default="data/lang/bbpe_2000",
         help="The lang dir containing word table and LG graph",
     )
 
@@ -748,7 +748,7 @@ def main():
     # we need cut ids to display recognition results.
     args.return_cuts = True
 
-    data_module = MultiDatasetAsrDataModule(args)
+    multidataset_datamodule = MultiDatasetAsrDataModule(args)
     multi_dataset = MultiDataset(args)
 
     def remove_short_utt(c: Cut):
@@ -759,31 +759,42 @@
         )
         return T > 0
 
-    test_sets_cuts = multi_dataset.test_cuts()
+    def tokenize_and_encode_text(c: Cut):
+        # Text normalize for each sample
+        text = c.supervisions[0].text
+        text = byte_encode(tokenize_by_ja_char(text))
+        c.supervisions[0].text = text
+        return c
 
-    test_sets = test_sets_cuts.keys()
-    test_dl = [
-        data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
-        for cuts_name in test_sets
-    ]
+    test_cuts = multi_dataset.test_cuts()
+    test_cuts = test_cuts.filter(remove_short_utt)
+    # test_cuts = test_cuts.map(tokenize_and_encode_text)
 
-    for test_set, test_dl in zip(test_sets, test_dl):
-        logging.info(f"Start decoding test set: {test_set}")
+    test_dl = multidataset_datamodule.test_dataloaders(test_cuts)
+
+    # test_sets = test_sets_cuts.keys()
+    # test_dl = [
+    #     data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
+    #     for cuts_name in test_sets
+    # ]
+
+    # for test_set, test_dl in zip(test_sets, test_dl):
+    logging.info("Start decoding test set")
 
-        results_dict = decode_dataset(
-            dl=test_dl,
-            params=params,
-            model=model,
-            sp=sp,
-            word_table=word_table,
-            decoding_graph=decoding_graph,
-        )
+    results_dict = decode_dataset(
+        dl=test_dl,
+        params=params,
+        model=model,
+        sp=sp,
+        word_table=word_table,
+        decoding_graph=decoding_graph,
+    )
 
-        save_results(
-            params=params,
-            test_set_name=test_set,
-            results_dict=results_dict,
-        )
+    save_results(
+        params=params,
+        test_set_name="test_set",
+        results_dict=results_dict,
+    )
 
     logging.info("Done!")
 
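
Reviewer note (not part of the commit): the behavioral change above is that
MultiDataset.test_cuts() is now consumed as one merged cut set rather than a
dict of per-corpus cut sets, so filtering, dataloader construction, decoding,
and save_results() each run once, under the fixed name "test_set". A toy
sketch of that control-flow change, with plain lists of durations standing in
for lhotse CutSets and a 0.5 s threshold invented for illustration:

    # Toy stand-ins: utterance durations in seconds instead of real Cuts.
    test_sets_cuts = {"ja_test": [0.3, 2.0], "en_test": [1.5]}

    def remove_short_utt(duration: float) -> bool:
        # Stand-in predicate; the real one checks that a cut still has
        # at least one frame left after the encoder's subsampling.
        return duration > 0.5  # hypothetical threshold

    # Old flow: one filtered "dataloader" per named test set.
    old_dls = {
        name: [d for d in cuts if remove_short_utt(d)]
        for name, cuts in test_sets_cuts.items()
    }

    # New flow: merge first, then filter and decode once.
    merged = [d for cuts in test_sets_cuts.values() for d in cuts]
    new_dl = [d for d in merged if remove_short_utt(d)]

    print(old_dls)  # {'ja_test': [2.0], 'en_test': [1.5]}
    print(new_dl)   # [2.0, 1.5] -- a single result set

One consequence worth flagging: per-corpus results are no longer reported
separately, since everything is decoded and saved under the single name
"test_set".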