fix decode script data module usage

Kinan Martin 2025-06-06 11:29:29 +09:00
parent ce894a7ba2
commit 6255ba5cb2


@@ -157,14 +157,14 @@ def get_parser():
     parser.add_argument(
         "--bpe-model",
         type=str,
-        default="data/lang_bbpe_2000/bbpe.model",
+        default="data/lang/bbpe_2000/bbpe.model",
         help="Path to the BPE model",
     )
 
     parser.add_argument(
         "--lang-dir",
         type=Path,
-        default="data/lang_bbpe_2000",
+        default="data/lang/bbpe_2000",
         help="The lang dir containing word table and LG graph",
     )
@@ -748,7 +748,7 @@ def main():
     # we need cut ids to display recognition results.
     args.return_cuts = True
-    data_module = MultiDatasetAsrDataModule(args)
+    multidataset_datamodule = MultiDatasetAsrDataModule(args)
     multi_dataset = MultiDataset(args)
 
     def remove_short_utt(c: Cut):
@@ -759,31 +759,42 @@ def main():
             )
         return T > 0
 
-    test_sets_cuts = multi_dataset.test_cuts()
+    def tokenize_and_encode_text(c: Cut):
+        # Text normalize for each sample
+        text = c.supervisions[0].text
+        text = byte_encode(tokenize_by_ja_char(text))
+        c.supervisions[0].text = text
+        return c
 
-    test_sets = test_sets_cuts.keys()
-    test_dl = [
-        data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
-        for cuts_name in test_sets
-    ]
+    test_cuts = multi_dataset.test_cuts()
+    test_cuts = test_cuts.filter(remove_short_utt)
+    # test_cuts = test_cuts.map(tokenize_and_encode_text)
 
-    for test_set, test_dl in zip(test_sets, test_dl):
-        logging.info(f"Start decoding test set: {test_set}")
+    test_dl = multidataset_datamodule.test_dataloaders(test_cuts)
 
-        results_dict = decode_dataset(
-            dl=test_dl,
-            params=params,
-            model=model,
-            sp=sp,
-            word_table=word_table,
-            decoding_graph=decoding_graph,
-        )
+    # test_sets = test_sets_cuts.keys()
+    # test_dl = [
+    #     data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
+    #     for cuts_name in test_sets
+    # ]
 
-        save_results(
-            params=params,
-            test_set_name=test_set,
-            results_dict=results_dict,
-        )
+    # for test_set, test_dl in zip(test_sets, test_dl):
+    logging.info("Start decoding test set")#: {test_set}")
+
+    results_dict = decode_dataset(
+        dl=test_dl,
+        params=params,
+        model=model,
+        sp=sp,
+        word_table=word_table,
+        decoding_graph=decoding_graph,
+    )
+
+    save_results(
+        params=params,
+        test_set_name="test_set",
+        results_dict=results_dict,
+    )
 
     logging.info("Done!")