support decoding

This commit is contained in:
marcoyang 2024-02-28 11:56:38 +08:00
parent b6f3a2b186
commit 7258271414
2 changed files with 23 additions and 12 deletions

View File

@@ -431,4 +431,11 @@ class MLSAsrDataModule:
logging.info(f"About to get dev cuts for {language}")
return load_manifest_lazy(
self.args.manifest_dir / f"mls-{language}_dev.jsonl.gz"
)
@lru_cache()
def mls_test_cuts(self, language: str) -> CutSet:
logging.info(f"About to get test cuts for {language}")
return load_manifest_lazy(
self.args.manifest_dir / f"mls-{language}_test.jsonl.gz"
)

View File

@@ -1,7 +1,8 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, # Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao) # Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -106,7 +107,7 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import MLSAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
@@ -755,7 +756,7 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser) MLSAsrDataModule.add_arguments(parser)
LmScorer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@@ -1014,18 +1015,21 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args) mls = MLSAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts() test_sets = []
test_other_cuts = librispeech.test_other_cuts() test_dls = []
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) test_languages = params.language.split(",")
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"] for language in test_languages:
test_dl = [test_clean_dl, test_other_dl] test_cuts = mls.mls_test_cuts(language)
test_dl = mls.test_dataloaders(test_cuts)
for test_set, test_dl in zip(test_sets, test_dl): test_sets.append(f"test-{language}")
test_dls.append(test_dl)
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,