mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-10 10:32:17 +00:00)

support decoding

This commit is contained in:
  parent b6f3a2b186
  commit 7258271414
@@ -432,3 +432,10 @@ class MLSAsrDataModule:
         return load_manifest_lazy(
             self.args.manifest_dir / f"mls-{language}_dev.jsonl.gz"
         )
+
+    @lru_cache()
+    def mls_test_cuts(self, language: str) -> CutSet:
+        logging.info(f"About to get test cuts for {language}")
+        return load_manifest_lazy(
+            self.args.manifest_dir / f"mls-{language}_test.jsonl.gz"
+        )
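Note (not part of the diff): the new mls_test_cuts accessor follows the same lru_cache + load_manifest_lazy pattern as the existing dev-cuts accessor, so each per-language test CutSet is read lazily from its JSONL manifest and memoized after the first call. A minimal standalone sketch of that pattern, assuming lhotse is installed and the mls-<language>_test.jsonl.gz manifests already exist; DummyDataModule and its manifest_dir argument are illustrative stand-ins, not icefall code:

# Sketch only: DummyDataModule stands in for the real data module.
from functools import lru_cache
from pathlib import Path

from lhotse import CutSet, load_manifest_lazy


class DummyDataModule:
    def __init__(self, manifest_dir: Path) -> None:
        self.manifest_dir = manifest_dir

    @lru_cache()
    def mls_test_cuts(self, language: str) -> CutSet:
        # Repeated calls with the same language reuse the cached CutSet;
        # load_manifest_lazy streams the manifest instead of loading it fully.
        return load_manifest_lazy(
            self.manifest_dir / f"mls-{language}_test.jsonl.gz"
        )


# Usage (assuming the MLS test manifests have already been prepared):
# cuts = DummyDataModule(Path("data/manifests")).mls_test_cuts("english")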
@@ -1,7 +1,8 @@
 #!/usr/bin/env python3
 #
-# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
-#                                                  Zengwei Yao)
+# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                  Zengwei Yao,
+#                                                  Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -106,7 +107,7 @@ import k2
 import sentencepiece as spm
 import torch
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import MLSAsrDataModule
 from beam_search import (
     beam_search,
     fast_beam_search_nbest,
@@ -755,7 +756,7 @@ def save_results(
 @torch.no_grad()
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    MLSAsrDataModule.add_arguments(parser)
     LmScorer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
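Note (not part of the diff): this hunk only swaps which data module registers its CLI options on the shared parser; each component contributes its own flags through an add_arguments hook, so replacing LibriSpeechAsrDataModule with MLSAsrDataModule changes the data-related options decode.py exposes without touching the rest of main(). A hedged sketch of that composition pattern with stand-in classes (FakeDataModule, FakeLmScorer and their flags are illustrative, not icefall APIs):

# Sketch only: stand-in classes showing the add_arguments composition pattern.
import argparse


class FakeDataModule:
    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> None:
        group = parser.add_argument_group(title="Data related options")
        group.add_argument("--manifest-dir", type=str, default="data/fbank")


class FakeLmScorer:
    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> None:
        group = parser.add_argument_group(title="LM related options")
        group.add_argument("--lm-scale", type=float, default=0.3)


parser = argparse.ArgumentParser()
FakeDataModule.add_arguments(parser)
FakeLmScorer.add_arguments(parser)
args = parser.parse_args(["--manifest-dir", "data/fbank", "--lm-scale", "0.5"])
print(args.manifest_dir, args.lm_scale)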
@@ -1014,18 +1015,21 @@ def main():
 
     # we need cut ids to display recognition results.
     args.return_cuts = True
-    librispeech = LibriSpeechAsrDataModule(args)
+    mls = MLSAsrDataModule(args)
 
-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
+    test_sets = []
+    test_dls = []
 
-    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+    test_languages = params.language.split(",")
 
-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
+    for language in test_languages:
+        test_cuts = mls.mls_test_cuts(language)
+        test_dl = mls.test_dataloaders(test_cuts)
 
-    for test_set, test_dl in zip(test_sets, test_dl):
+        test_sets.append(f"test-{language}")
+        test_dls.append(test_dl)
+
+    for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
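Note (not part of the diff): instead of the fixed LibriSpeech test-clean/test-other pair, decoding now iterates over a comma-separated language list, first collecting one (name, dataloader) pair per language and then decoding each pair with zip(). A minimal, self-contained sketch of that control flow; the helpers below are placeholders for mls.mls_test_cuts, mls.test_dataloaders and decode_dataset, and "english,german" is only an assumed example value for params.language:

# Sketch only: placeholder helpers mimic the two-pass decode loop above.
from typing import Dict, List


def get_test_cuts(language: str) -> List[str]:
    # Placeholder for mls.mls_test_cuts(language).
    return [f"{language}-utt-{i}" for i in range(3)]


def make_dataloader(cuts: List[str]) -> List[str]:
    # Placeholder for mls.test_dataloaders(test_cuts).
    return cuts


def decode(dl: List[str]) -> Dict[str, int]:
    # Placeholder for decode_dataset(dl=..., params=..., ...).
    return {"num_utterances": len(dl)}


language = "english,german"  # assumed format of params.language

test_sets: List[str] = []
test_dls: List[List[str]] = []

# First pass: build one (name, dataloader) pair per requested language.
for lang in language.split(","):
    test_sets.append(f"test-{lang}")
    test_dls.append(make_dataloader(get_test_cuts(lang)))

# Second pass: decode each test set, mirroring the zip() loop in main().
for test_set, test_dl in zip(test_sets, test_dls):
    print(test_set, decode(test_dl))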