Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 10:02:22 +00:00.

commit 6255ba5cb2
parent ce894a7ba2

    fix decode script data module usage
@@ -157,14 +157,14 @@ def get_parser():
     parser.add_argument(
         "--bpe-model",
         type=str,
-        default="data/lang_bbpe_2000/bbpe.model",
+        default="data/lang/bbpe_2000/bbpe.model",
         help="Path to the BPE model",
     )

     parser.add_argument(
         "--lang-dir",
         type=Path,
-        default="data/lang_bbpe_2000",
+        default="data/lang/bbpe_2000",
         help="The lang dir containing word table and LG graph",
     )
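These two defaults are consumed when the decode script loads the BPE model and the lang dir. As a rough sketch of the usual icefall pattern (the sentencepiece calls are the library's real API; the Lexicon lines are the typical icefall idiom, and the exact code in this script may differ):

import sentencepiece as spm
from icefall.lexicon import Lexicon

# Load the byte-level BPE model from the new default path.
sp = spm.SentencePieceProcessor()
sp.load("data/lang/bbpe_2000/bbpe.model")

# The lang dir provides the word table (and LG graph) used in decoding.
lexicon = Lexicon("data/lang/bbpe_2000")
word_table = lexicon.word_table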
@@ -748,7 +748,7 @@ def main():

     # we need cut ids to display recognition results.
     args.return_cuts = True
-    data_module = MultiDatasetAsrDataModule(args)
+    multidataset_datamodule = MultiDatasetAsrDataModule(args)
     multi_dataset = MultiDataset(args)

     def remove_short_utt(c: Cut):
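The rename above makes the variable reflect two distinct roles: MultiDataset supplies the test cuts, while MultiDatasetAsrDataModule turns a cut set into a dataloader. Below is a self-contained toy sketch of that division of labor; ToyMultiDataset, ToyAsrDataModule, and the Cut dataclass are hypothetical stand-ins, not icefall's actual classes.

from dataclasses import dataclass
from typing import Dict, Iterator, List


@dataclass
class Cut:
    # Stand-in for a lhotse Cut: one utterance plus its transcript.
    id: str
    num_frames: int
    text: str


class ToyMultiDataset:
    # Plays the role of MultiDataset: merges the test cuts of several corpora.
    def __init__(self, corpora: Dict[str, List[Cut]]):
        self.corpora = corpora

    def test_cuts(self) -> List[Cut]:
        # After this commit, decode.py consumes one combined cut list
        # rather than a dict of per-corpus cut sets.
        return [c for cuts in self.corpora.values() for c in cuts]


class ToyAsrDataModule:
    # Plays the role of MultiDatasetAsrDataModule: builds the dataloader.
    def __init__(self, batch_size: int = 2):
        self.batch_size = batch_size

    def test_dataloaders(self, cuts: List[Cut]) -> Iterator[List[Cut]]:
        for i in range(0, len(cuts), self.batch_size):
            yield cuts[i : i + self.batch_size]


def remove_short_utt(c: Cut) -> bool:
    # Mirrors the diff's filter: drop cuts too short to survive subsampling.
    return c.num_frames > 7


multi_dataset = ToyMultiDataset(
    {
        "corpus_a": [Cut("a-0", 100, "hello"), Cut("a-1", 3, "hi")],
        "corpus_b": [Cut("b-0", 80, "world")],
    }
)
data_module = ToyAsrDataModule()

test_cuts = [c for c in multi_dataset.test_cuts() if remove_short_utt(c)]
test_dl = data_module.test_dataloaders(test_cuts)

for batch in test_dl:
    print([c.id for c in batch])  # ['a-0', 'b-0']; a-1 was filtered out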
@@ -759,31 +759,42 @@ def main():
         )
         return T > 0

-    test_sets_cuts = multi_dataset.test_cuts()
+    def tokenize_and_encode_text(c: Cut):
+        # Text normalize for each sample
+        text = c.supervisions[0].text
+        text = byte_encode(tokenize_by_ja_char(text))
+        c.supervisions[0].text = text
+        return c

-    test_sets = test_sets_cuts.keys()
-    test_dl = [
-        data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
-        for cuts_name in test_sets
-    ]
+    test_cuts = multi_dataset.test_cuts()
+    test_cuts = test_cuts.filter(remove_short_utt)
+    # test_cuts = test_cuts.map(tokenize_and_encode_text)

-    for test_set, test_dl in zip(test_sets, test_dl):
-        logging.info(f"Start decoding test set: {test_set}")
+    test_dl = multidataset_datamodule.test_dataloaders(test_cuts)

-        results_dict = decode_dataset(
-            dl=test_dl,
-            params=params,
-            model=model,
-            sp=sp,
-            word_table=word_table,
-            decoding_graph=decoding_graph,
-        )
+    # test_sets = test_sets_cuts.keys()
+    # test_dl = [
+    #     data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
+    #     for cuts_name in test_sets
+    # ]

-        save_results(
-            params=params,
-            test_set_name=test_set,
-            results_dict=results_dict,
-        )
+    # for test_set, test_dl in zip(test_sets, test_dl):
+    logging.info("Start decoding test set")#: {test_set}")
+
+    results_dict = decode_dataset(
+        dl=test_dl,
+        params=params,
+        model=model,
+        sp=sp,
+        word_table=word_table,
+        decoding_graph=decoding_graph,
+    )
+
+    save_results(
+        params=params,
+        test_set_name="test_set",
+        results_dict=results_dict,
+    )

     logging.info("Done!")
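Finally, the new tokenize_and_encode_text helper (its .map() hook stays commented out in this commit) normalizes each supervision's text for byte-level BPE in two steps: split Japanese text into characters, then byte-encode the result. Here is a runnable illustration using simplified stand-ins for icefall's tokenize_by_ja_char and byte_encode; the real utilities are more involved.

def tokenize_by_ja_char(text: str) -> str:
    # Stand-in: space-separate non-ASCII (e.g. Japanese) characters while
    # leaving ASCII words intact; icefall's version is more careful.
    pieces = []
    for ch in text:
        pieces.append(f" {ch} " if ord(ch) > 0x7F else ch)
    return " ".join("".join(pieces).split())


def byte_encode(text: str) -> str:
    # Stand-in: expose the UTF-8 bytes as individual characters, the idea
    # behind byte-level BPE; icefall remaps bytes to printable symbols.
    return text.encode("utf-8").decode("latin-1")


print(tokenize_by_ja_char("ASRは楽しい"))  # -> ASR は 楽 し い
print(byte_encode(tokenize_by_ja_char("ASRは楽しい")))  # UTF-8 bytes as latin-1 text

If the hook is re-enabled, the commented line test_cuts = test_cuts.map(tokenize_and_encode_text) would apply this normalization lazily to every cut before decoding.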