From a9edd7cc3de72c289e59ff760652110e30953408 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Tue, 27 Feb 2024 18:02:16 +0800 Subject: [PATCH] update the asr datamodule --- egs/mls/ASR/zipformer/asr_datamodule.py | 133 +++++++----------------- 1 file changed, 37 insertions(+), 96 deletions(-) diff --git a/egs/mls/ASR/zipformer/asr_datamodule.py b/egs/mls/ASR/zipformer/asr_datamodule.py index 1e16e8077..d30983838 100644 --- a/egs/mls/ASR/zipformer/asr_datamodule.py +++ b/egs/mls/ASR/zipformer/asr_datamodule.py @@ -84,19 +84,12 @@ class MLSAsrDataModule: ) group.add_argument( "--language", - type=str2bool, + type=str, default="all", - choices=["english", "german", "dutch", "french", "spanish", "italian", "portuguese", "polish", "all"], - help="""If all, use all the languages, other - use 960h LibriSpeech. Otherwise, use 100h subset.""", + # choices=["english", "german", "dutch", "french", "spanish", "italian", "portuguese", "polish", "all"], + help="""A list of languages separated by comma. If all, use all + the languages""", ) - group.add_argument( - "--mini-libri", - type=str2bool, - default=False, - help="True for mini librispeech", - ) - group.add_argument( "--manifest-dir", type=Path, @@ -405,89 +398,37 @@ class MLSAsrDataModule: return test_dl @lru_cache() - def train_clean_5_cuts(self) -> CutSet: - logging.info("mini_librispeech: About to get train-clean-5 cuts") + def train_mls_cuts(self) -> CutSet: + if self.args.language == "all": + languages = ["english", "german", "dutch", "french", "spanish", "italian", "portuguese", "polish",] + else: + languages = self.args.language.split(",") + if len(languages) == 1: + l = languages[0] + logging.info(f"About to get {l} cuts") + return load_manifest_lazy( + self.args.manifest_dir / f"mls-{l}_train.jsonl.gz" + ) + else: + all_cuts = [] + all_cuts_len = [] + for l in languages: + logging.info(f"About to get {l} cuts") + current_cuts = load_manifest_lazy( + self.args.manifest_dir / f"mls-{l}_train.jsonl.gz" + ) + current_cuts_len = len(current_cuts) + all_cuts.append(current_cuts) + all_cuts_len.append(current_cuts_len) + return CutSet.mux( + *all_cuts, + weights=all_cuts_len, + stop_early=True, + ) + + @lru_cache() + def mls_dev_cuts(self, language: str) -> CutSet: + logging.info(f"About to get dev cuts for {language}") return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz" - ) - - @lru_cache() - def train_clean_100_cuts(self) -> CutSet: - logging.info("About to get train-clean-100 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" - ) - - @lru_cache() - def train_clean_360_cuts(self) -> CutSet: - logging.info("About to get train-clean-360 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz" - ) - - @lru_cache() - def train_other_500_cuts(self) -> CutSet: - logging.info("About to get train-other-500 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" - ) - - @lru_cache() - def train_all_shuf_cuts(self) -> CutSet: - logging.info( - "About to get the shuffled train-clean-100, \ - train-clean-360 and train-other-500 cuts" - ) - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" - ) - - @lru_cache() - def dev_clean_2_cuts(self) -> CutSet: - logging.info("mini_librispeech: About to get dev-clean-2 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz" - ) - - @lru_cache() - def dev_clean_cuts(self) -> CutSet: - logging.info("About to get dev-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" - ) - - @lru_cache() - def dev_other_cuts(self) -> CutSet: - logging.info("About to get dev-other cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" - ) - - @lru_cache() - def test_clean_cuts(self) -> CutSet: - logging.info("About to get test-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" - ) - - @lru_cache() - def test_other_cuts(self) -> CutSet: - logging.info("About to get test-other cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" - ) - - @lru_cache() - def gigaspeech_subset_small_cuts(self) -> CutSet: - logging.info("About to get Gigaspeech subset-S cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz") - - @lru_cache() - def gigaspeech_dev_cuts(self) -> CutSet: - logging.info("About to get Gigaspeech dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") - - @lru_cache() - def gigaspeech_test_cuts(self) -> CutSet: - logging.info("About to get Gigaspeech test cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") + self.args.manifest_dir / f"mls-{language}_dev.jsonl.gz" + ) \ No newline at end of file