Fix decode script's data-module usage

This commit is contained in:
Kinan Martin 2025-06-06 11:29:29 +09:00
parent ce894a7ba2
commit 6255ba5cb2

View File

@ -157,14 +157,14 @@ def get_parser():
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bbpe_2000/bbpe.model",
default="data/lang/bbpe_2000/bbpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bbpe_2000",
default="data/lang/bbpe_2000",
help="The lang dir containing word table and LG graph",
)
@ -748,7 +748,7 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
data_module = MultiDatasetAsrDataModule(args)
multidataset_datamodule = MultiDatasetAsrDataModule(args)
multi_dataset = MultiDataset(args)
def remove_short_utt(c: Cut):
@ -759,31 +759,42 @@ def main():
)
return T > 0
test_sets_cuts = multi_dataset.test_cuts()
def tokenize_and_encode_text(c: Cut):
# Text normalize for each sample
text = c.supervisions[0].text
text = byte_encode(tokenize_by_ja_char(text))
c.supervisions[0].text = text
return c
test_sets = test_sets_cuts.keys()
test_dl = [
data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
for cuts_name in test_sets
]
test_cuts = multi_dataset.test_cuts()
test_cuts = test_cuts.filter(remove_short_utt)
# test_cuts = test_cuts.map(tokenize_and_encode_text)
for test_set, test_dl in zip(test_sets, test_dl):
logging.info(f"Start decoding test set: {test_set}")
test_dl = multidataset_datamodule.test_dataloaders(test_cuts)
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)
# test_sets = test_sets_cuts.keys()
# test_dl = [
# data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
# for cuts_name in test_sets
# ]
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
# for test_set, test_dl in zip(test_sets, test_dl):
logging.info("Start decoding test set")#: {test_set}")
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name="test_set",
results_dict=results_dict,
)
logging.info("Done!")