# Patch summary (free-form text before the first "diff --git" line is
# ignored by patch-applying tools):
#
#   * Drops the imports of glob, re, lhotse and typing.List; adds
#     functools.lru_cache.
#   * The training mux (method name lies outside the hunk) now also mixes
#     LibriSpeech train-clean-100 / train-clean-360 / train-other-500,
#     each weighted by len() of its CutSet.
#   * dev_cuts() changes its return type from List[CutSet] to a single
#     CutSet built with CutSet.mux over AISHELL-2 dev, dev-clean and
#     dev-other.
#   * test_cuts() gains "librispeech_test_clean" and
#     "librispeech_test_other" entries.
#   * Adds seven @lru_cache() loader methods, one per LibriSpeech manifest.
#
# NOTE(review): the new loaders read `self.args.manifest_dir`, while every
#   pre-existing load visible in this patch reads `self.fbank_dir`.
#   MultiDataset.__init__ is not part of this patch -- confirm the class
#   really sets `self.args`, otherwise these loaders raise AttributeError.
# NOTE(review): @lru_cache() on an instance method keys the cache on `self`
#   and keeps the instance alive for the cache's lifetime (ruff B019).
# NOTE(review): callers that iterated the old List[CutSet] result of
#   dev_cuts() must be updated; they are not visible here.
diff --git a/egs/multi_zh_en/ASR/zipformer/multi_dataset.py b/egs/multi_zh_en/ASR/zipformer/multi_dataset.py
index 7a357c83d..435d76ddb 100644
--- a/egs/multi_zh_en/ASR/zipformer/multi_dataset.py
+++ b/egs/multi_zh_en/ASR/zipformer/multi_dataset.py
@@ -15,13 +15,11 @@
 # limitations under the License.
 
 
-import glob
 import logging
-import re
+from functools import lru_cache
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict
 
-import lhotse
 from lhotse import CutSet, load_manifest_lazy
 
 
@@ -44,14 +42,25 @@ class MultiDataset:
             self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
         )
 
+        # LibriSpeech
+        train_clean_100_cuts = self.train_clean_100_cuts()
+        train_clean_360_cuts = self.train_clean_360_cuts()
+        train_other_500_cuts = self.train_other_500_cuts()
+
         return CutSet.mux(
             aishell_2_cuts,
+            train_clean_100_cuts,
+            train_clean_360_cuts,
+            train_other_500_cuts,
             weights=[
                 len(aishell_2_cuts),
+                len(train_clean_100_cuts),
+                len(train_clean_360_cuts),
+                len(train_other_500_cuts),
             ],
         )
 
-    def dev_cuts(self) -> List[CutSet]:
+    def dev_cuts(self) -> CutSet:
         logging.info("About to get multidataset dev cuts")
 
         # AISHELL-2
@@ -60,9 +69,20 @@ class MultiDataset:
             self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
         )
 
-        return [
+        # LibriSpeech
+        dev_clean_cuts = self.dev_clean_cuts()
+        dev_other_cuts = self.dev_other_cuts()
+
+        return CutSet.mux(
             aishell2_dev_cuts,
-        ]
+            dev_clean_cuts,
+            dev_other_cuts,
+            weights=[
+                len(aishell2_dev_cuts),
+                len(dev_clean_cuts),
+                len(dev_other_cuts),
+            ],
+        )
 
     def test_cuts(self) -> Dict[str, CutSet]:
         logging.info("About to get multidataset test cuts")
@@ -76,7 +96,62 @@ class MultiDataset:
             self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
         )
 
+        # LibriSpeech
+        test_clean_cuts = self.test_clean_cuts()
+        test_other_cuts = self.test_other_cuts()
+
         return {
             "aishell-2_test": aishell2_test_cuts,
             "aishell-2_dev": aishell2_dev_cuts,
+            "librispeech_test_clean": test_clean_cuts,
+            "librispeech_test_other": test_other_cuts,
         }
+
+    @lru_cache()
+    def train_clean_100_cuts(self) -> CutSet:
+        logging.info("About to get train-clean-100 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
+        )
+
+    @lru_cache()
+    def train_clean_360_cuts(self) -> CutSet:
+        logging.info("About to get train-clean-360 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
+        )
+
+    @lru_cache()
+    def train_other_500_cuts(self) -> CutSet:
+        logging.info("About to get train-other-500 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
+        )
+
+    @lru_cache()
+    def dev_clean_cuts(self) -> CutSet:
+        logging.info("About to get dev-clean cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
+        )
+
+    @lru_cache()
+    def dev_other_cuts(self) -> CutSet:
+        logging.info("About to get dev-other cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
+        )
+
+    @lru_cache()
+    def test_clean_cuts(self) -> CutSet:
+        logging.info("About to get test-clean cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
+        )
+
+    @lru_cache()
+    def test_other_cuts(self) -> CutSet:
+        logging.info("About to get test-other cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
+        )