From 725827141424f38603535a79b4f2f5b37c1684bb Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Wed, 28 Feb 2024 11:56:38 +0800
Subject: [PATCH] support decoding

---
 egs/mls/ASR/zipformer/asr_datamodule.py |  7 +++++++
 egs/mls/ASR/zipformer/decode.py         | 28 ++++++++++++++-----------
 2 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/egs/mls/ASR/zipformer/asr_datamodule.py b/egs/mls/ASR/zipformer/asr_datamodule.py
index d30983838..3acaa4ccc 100644
--- a/egs/mls/ASR/zipformer/asr_datamodule.py
+++ b/egs/mls/ASR/zipformer/asr_datamodule.py
@@ -431,4 +431,11 @@ class MLSAsrDataModule:
         logging.info(f"About to get dev cuts for {language}")
         return load_manifest_lazy(
             self.args.manifest_dir / f"mls-{language}_dev.jsonl.gz"
+        )
+
+    @lru_cache()
+    def mls_test_cuts(self, language: str) -> CutSet:
+        logging.info(f"About to get test cuts for {language}")
+        return load_manifest_lazy(
+            self.args.manifest_dir / f"mls-{language}_test.jsonl.gz"
         )
\ No newline at end of file
diff --git a/egs/mls/ASR/zipformer/decode.py b/egs/mls/ASR/zipformer/decode.py
index 339e253e6..ddf1a124c 100755
--- a/egs/mls/ASR/zipformer/decode.py
+++ b/egs/mls/ASR/zipformer/decode.py
@@ -1,7 +1,8 @@
 #!/usr/bin/env python3
 #
-# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
-#                                                 Zengwei Yao)
+# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                 Zengwei Yao,
+#                                                 Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -106,7 +107,7 @@ import k2
 import sentencepiece as spm
 import torch
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import MLSAsrDataModule
 from beam_search import (
     beam_search,
     fast_beam_search_nbest,
@@ -755,7 +756,7 @@ def save_results(
 @torch.no_grad()
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    MLSAsrDataModule.add_arguments(parser)
     LmScorer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
@@ -1014,18 +1015,21 @@ def main():
 
     # we need cut ids to display recognition results.
     args.return_cuts = True
-    librispeech = LibriSpeechAsrDataModule(args)
+    mls = MLSAsrDataModule(args)
 
-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
+    test_sets = []
+    test_dls = []
 
-    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+    test_languages = params.language.split(",")
 
-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
+    for language in test_languages:
+        test_cuts = mls.mls_test_cuts(language)
+        test_dl = mls.test_dataloaders(test_cuts)
 
-    for test_set, test_dl in zip(test_sets, test_dl):
+        test_sets.append(f"test-{language}")
+        test_dls.append(test_dl)
+
+    for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,