mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-10 18:42:19 +00:00
fix decode script data module usage
parent ce894a7ba2
commit 6255ba5cb2
@@ -157,14 +157,14 @@ def get_parser():
     parser.add_argument(
         "--bpe-model",
         type=str,
-        default="data/lang_bbpe_2000/bbpe.model",
+        default="data/lang/bbpe_2000/bbpe.model",
         help="Path to the BPE model",
     )
 
     parser.add_argument(
         "--lang-dir",
         type=Path,
-        default="data/lang_bbpe_2000",
+        default="data/lang/bbpe_2000",
         help="The lang dir containing word table and LG graph",
     )
 
@@ -748,7 +748,7 @@ def main():
 
     # we need cut ids to display recognition results.
     args.return_cuts = True
-    data_module = MultiDatasetAsrDataModule(args)
+    multidataset_datamodule = MultiDatasetAsrDataModule(args)
     multi_dataset = MultiDataset(args)
 
     def remove_short_utt(c: Cut):
@@ -759,31 +759,42 @@ def main():
             )
         return T > 0
 
-    test_sets_cuts = multi_dataset.test_cuts()
-
-    test_sets = test_sets_cuts.keys()
-    test_dl = [
-        data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
-        for cuts_name in test_sets
-    ]
-
-    for test_set, test_dl in zip(test_sets, test_dl):
-        logging.info(f"Start decoding test set: {test_set}")
-
-        results_dict = decode_dataset(
-            dl=test_dl,
-            params=params,
-            model=model,
-            sp=sp,
-            word_table=word_table,
-            decoding_graph=decoding_graph,
-        )
-
-        save_results(
-            params=params,
-            test_set_name=test_set,
-            results_dict=results_dict,
-        )
+    def tokenize_and_encode_text(c: Cut):
+        # Text normalize for each sample
+        text = c.supervisions[0].text
+        text = byte_encode(tokenize_by_ja_char(text))
+        c.supervisions[0].text = text
+        return c
+
+    test_cuts = multi_dataset.test_cuts()
+    test_cuts = test_cuts.filter(remove_short_utt)
+    # test_cuts = test_cuts.map(tokenize_and_encode_text)
+
+    test_dl = multidataset_datamodule.test_dataloaders(test_cuts)
+
+    # test_sets = test_sets_cuts.keys()
+    # test_dl = [
+    #     data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
+    #     for cuts_name in test_sets
+    # ]
+
+    # for test_set, test_dl in zip(test_sets, test_dl):
+    logging.info("Start decoding test set")
+
+    results_dict = decode_dataset(
+        dl=test_dl,
+        params=params,
+        model=model,
+        sp=sp,
+        word_table=word_table,
+        decoding_graph=decoding_graph,
+    )
+
+    save_results(
+        params=params,
+        test_set_name="test_set",
+        results_dict=results_dict,
+    )
 
     logging.info("Done!")
 
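A note on the reworked flow: instead of building one dataloader per named test set and looping over them, the script now merges every test set into a single lhotse CutSet, filters out too-short utterances, and passes one dataloader to decode_dataset. The sketch below shows the same filter/map pattern in isolation; it is a minimal illustration under stated assumptions, not the recipe's code. The import paths for byte_encode and tokenize_by_ja_char are assumptions (icefall keeps byte-level BPE helpers in icefall.byte_utils and CJK tokenizers in icefall.utils; check the script's actual imports), and prepare_test_cuts is a hypothetical wrapper invented for this example, with a duration check standing in for the diff's remove_short_utt.

    from lhotse import CutSet
    from lhotse.cut import Cut

    # Assumed locations of icefall's byte-BPE helpers (verify in the recipe):
    from icefall.byte_utils import byte_encode
    from icefall.utils import tokenize_by_ja_char


    def tokenize_and_encode_text(c: Cut) -> Cut:
        # Space-separate Japanese characters, then re-encode the string at
        # the byte level so it matches what the bbpe model was trained on.
        text = c.supervisions[0].text
        c.supervisions[0].text = byte_encode(tokenize_by_ja_char(text))
        return c


    def prepare_test_cuts(cuts: CutSet, apply_text_norm: bool = False) -> CutSet:
        # CutSet.filter and CutSet.map are lazy: cuts are transformed one by
        # one as the dataloader iterates, so chaining costs nothing up front.
        cuts = cuts.filter(lambda c: c.duration >= 0.2)  # stand-in for remove_short_utt
        if apply_text_norm:
            cuts = cuts.map(tokenize_and_encode_text)
        return cuts

Because filter and map return lazy views, the commented-out test_cuts.map(tokenize_and_encode_text) line in the diff can be re-enabled without changing anything else in main(); the normalization would then run per cut during iteration.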