From cd20e2155275954241b716d89134467d92d1f5b5 Mon Sep 17 00:00:00 2001
From: JinZr <60612200+JinZr@users.noreply.github.com>
Date: Wed, 22 Nov 2023 17:24:21 +0800
Subject: [PATCH] minor updates on flags

---
 egs/multi_zh_en/ASR/zipformer/decode.py |  14 ++
 .../ASR/zipformer/multi_dataset.py      | 127 +++++++++++++-----
 egs/multi_zh_en/ASR/zipformer/train.py  |  14 ++
 3 files changed, 119 insertions(+), 36 deletions(-)

diff --git a/egs/multi_zh_en/ASR/zipformer/decode.py b/egs/multi_zh_en/ASR/zipformer/decode.py
index dce1d9a37..a717df062 100755
--- a/egs/multi_zh_en/ASR/zipformer/decode.py
+++ b/egs/multi_zh_en/ASR/zipformer/decode.py
@@ -311,6 +311,20 @@ def get_parser():
         help="Whether to use TAL-CSASR training data.",
     )
 
+    parser.add_argument(
+        "--use-librispeech",
+        type=str2bool,
+        default=False,
+        help="Whether to use LibriSpeech training data.",
+    )
+
+    parser.add_argument(
+        "--use-aishell2",
+        type=str2bool,
+        default=False,
+        help="Whether to use Aishell-2 training data.",
+    )
+
     add_model_arguments(parser)
 
     return parser
diff --git a/egs/multi_zh_en/ASR/zipformer/multi_dataset.py b/egs/multi_zh_en/ASR/zipformer/multi_dataset.py
index c181c8af3..b159b16e1 100644
--- a/egs/multi_zh_en/ASR/zipformer/multi_dataset.py
+++ b/egs/multi_zh_en/ASR/zipformer/multi_dataset.py
@@ -34,15 +34,18 @@ class MultiDataset:
         """
         self.fbank_dir = Path(args.manifest_dir)
         self.use_tal_csasr = args.use_tal_csasr
+        self.use_librispeech = args.use_librispeech
+        self.use_aishell2 = args.use_aishell2
 
     def train_cuts(self) -> CutSet:
         logging.info("About to get multidataset train cuts")
 
         # AISHELL-2
-        logging.info("Loading Aishell-2 in lazy mode")
-        aishell_2_cuts = load_manifest_lazy(
-            self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
-        )
+        if self.use_aishell2:
+            logging.info("Loading Aishell-2 in lazy mode")
+            aishell_2_cuts = load_manifest_lazy(
+                self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
+            )
 
         # TAL-CSASR
         if self.use_tal_csasr:
@@ -52,11 +55,13 @@ class MultiDataset:
             )
 
         # LibriSpeech
-        train_clean_100_cuts = self.train_clean_100_cuts()
-        train_clean_360_cuts = self.train_clean_360_cuts()
-        train_other_500_cuts = self.train_other_500_cuts()
+        if self.use_librispeech:
+            logging.info("Loading LibriSpeech in lazy mode")
+            train_clean_100_cuts = self.train_clean_100_cuts()
+            train_clean_360_cuts = self.train_clean_360_cuts()
+            train_other_500_cuts = self.train_other_500_cuts()
 
-        if self.use_tal_csasr:
+        if self.use_tal_csasr and self.use_librispeech and self.use_aishell2:
             return CutSet.mux(
                 aishell_2_cuts,
                 train_clean_100_cuts,
@@ -71,18 +76,43 @@ class MultiDataset:
                     len(tal_csasr_cuts),
                 ],
             )
-        return CutSet.mux(
-            aishell_2_cuts,
-            train_clean_100_cuts,
-            train_clean_360_cuts,
-            train_other_500_cuts,
-            weights=[
-                len(aishell_2_cuts),
-                len(train_clean_100_cuts),
-                len(train_clean_360_cuts),
-                len(train_other_500_cuts),
-            ],
-        )
+        elif not self.use_tal_csasr and self.use_librispeech and self.use_aishell2:
+            return CutSet.mux(
+                aishell_2_cuts,
+                train_clean_100_cuts,
+                train_clean_360_cuts,
+                train_other_500_cuts,
+                weights=[
+                    len(aishell_2_cuts),
+                    len(train_clean_100_cuts),
+                    len(train_clean_360_cuts),
+                    len(train_other_500_cuts),
+                ],
+            )
+        elif self.use_tal_csasr and not self.use_librispeech and self.use_aishell2:
+            return CutSet.mux(
+                aishell_2_cuts,
+                tal_csasr_cuts,
+                weights=[
+                    len(aishell_2_cuts),
+                    len(tal_csasr_cuts),
+                ],
+            )
+        elif self.use_tal_csasr and self.use_librispeech and not self.use_aishell2:
+            return CutSet.mux(
+                train_clean_100_cuts,
+                train_clean_360_cuts,
+                train_other_500_cuts,
+                tal_csasr_cuts,
+                weights=[
+                    len(train_clean_100_cuts),
+                    len(train_clean_360_cuts),
+                    len(train_other_500_cuts),
+                    len(tal_csasr_cuts),
+                ],
+            )
+        else:
+            raise NotImplementedError
 
     def dev_cuts(self) -> CutSet:
         logging.info("About to get multidataset dev cuts")
@@ -97,14 +127,21 @@ class MultiDataset:
         dev_clean_cuts = self.dev_clean_cuts()
         dev_other_cuts = self.dev_other_cuts()
 
+        logging.info("Loading TAL-CSASR set in lazy mode")
+        tal_csasr_dev_cuts = load_manifest_lazy(
+            self.fbank_dir / "tal_csasr_cuts_dev_set.jsonl.gz"
+        )
+
         return CutSet.mux(
             aishell2_dev_cuts,
             dev_clean_cuts,
             dev_other_cuts,
+            tal_csasr_dev_cuts,
             weights=[
                 len(aishell2_dev_cuts),
                 len(dev_clean_cuts),
                 len(dev_other_cuts),
+                len(tal_csasr_dev_cuts),
             ],
         )
 
@@ -112,31 +149,49 @@ class MultiDataset:
         logging.info("About to get multidataset test cuts")
 
         # AISHELL-2
-        logging.info("Loading Aishell-2 set in lazy mode")
-        aishell2_test_cuts = load_manifest_lazy(
-            self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
-        )
-        aishell2_dev_cuts = load_manifest_lazy(
-            self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
-        )
+        if self.use_aishell2:
+            logging.info("Loading Aishell-2 set in lazy mode")
+            aishell2_test_cuts = load_manifest_lazy(
+                self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
+            )
+            aishell2_dev_cuts = load_manifest_lazy(
+                self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
+            )
 
         # LibriSpeech
-        test_clean_cuts = self.test_clean_cuts()
-        test_other_cuts = self.test_other_cuts()
+        if self.use_librispeech:
+            test_clean_cuts = self.test_clean_cuts()
+            test_other_cuts = self.test_other_cuts()
 
         logging.info("Loading TAL-CSASR set in lazy mode")
-        tal_csasr_cuts = load_manifest_lazy(
+        tal_csasr_test_cuts = load_manifest_lazy(
             self.fbank_dir / "tal_csasr_cuts_test_set.jsonl.gz"
         )
+        tal_csasr_dev_cuts = load_manifest_lazy(
+            self.fbank_dir / "tal_csasr_cuts_dev_set.jsonl.gz"
+        )
 
-        return {
-            "aishell-2_test": aishell2_test_cuts,
-            "aishell-2_dev": aishell2_dev_cuts,
-            "librispeech_test_clean": test_clean_cuts,
-            "librispeech_test_other": test_other_cuts,
-            "tal_csasr_cuts_test": tal_csasr_cuts,
+        test_cuts = {
+            "tal_csasr_test": tal_csasr_test_cuts,
+            "tal_csasr_dev": tal_csasr_dev_cuts,
         }
 
+        if self.use_aishell2:
+            test_cuts.update(
+                {
+                    "aishell-2_test": aishell2_test_cuts,
+                    "aishell-2_dev": aishell2_dev_cuts,
+                }
+            )
+        if self.use_librispeech:
+            test_cuts.update(
+                {
+                    "librispeech_test_clean": test_clean_cuts,
+                    "librispeech_test_other": test_other_cuts,
+                }
+            )
+        return test_cuts
+
     @lru_cache()
     def train_clean_100_cuts(self) -> CutSet:
         logging.info("About to get train-clean-100 cuts")
diff --git a/egs/multi_zh_en/ASR/zipformer/train.py b/egs/multi_zh_en/ASR/zipformer/train.py
index e1bf65da9..a44bfb6fa 100755
--- a/egs/multi_zh_en/ASR/zipformer/train.py
+++ b/egs/multi_zh_en/ASR/zipformer/train.py
@@ -475,6 +475,20 @@ def get_parser():
         help="Whether to use TAL-CSASR training data.",
     )
 
+    parser.add_argument(
+        "--use-librispeech",
+        type=str2bool,
+        default=False,
+        help="Whether to use LibriSpeech training data.",
+    )
+
+    parser.add_argument(
+        "--use-aishell2",
+        type=str2bool,
+        default=False,
+        help="Whether to use Aishell-2 training data.",
+    )
+
     add_model_arguments(parser)
 
     return parser
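
Note (not part of the patch): the train_cuts() hunk above enumerates four flag combinations explicitly and raises NotImplementedError for any other combination (e.g. LibriSpeech only). For readers tracing that logic, the sketch below expresses the same idea generically: collect whichever training corpora are enabled and mux them, weighting each stream by its number of cuts, just as every branch of the elif chain does. The helper name mux_enabled_train_cuts, its ds parameter, and the ValueError for an empty selection are hypothetical; CutSet.mux, load_manifest_lazy, the manifest filenames, and the LibriSpeech helper methods are taken from the patched file.

# Illustrative sketch only -- not part of the patch; the helper name and
# the `ds` parameter (a MultiDataset instance with the use_aishell2 /
# use_librispeech / use_tal_csasr flags set) are hypothetical.
from lhotse import CutSet, load_manifest_lazy


def mux_enabled_train_cuts(ds) -> CutSet:
    cut_sets = []

    if ds.use_aishell2:
        cut_sets.append(
            load_manifest_lazy(ds.fbank_dir / "aishell2_cuts_train.jsonl.gz")
        )
    if ds.use_librispeech:
        cut_sets += [
            ds.train_clean_100_cuts(),
            ds.train_clean_360_cuts(),
            ds.train_other_500_cuts(),
        ]
    if ds.use_tal_csasr:
        cut_sets.append(
            load_manifest_lazy(ds.fbank_dir / "tal_csasr_cuts_train_set.jsonl.gz")
        )

    if not cut_sets:
        # The patch raises NotImplementedError for unsupported combinations;
        # here only the "nothing enabled" case is rejected.
        raise ValueError(
            "Enable at least one of --use-aishell2, --use-librispeech, --use-tal-csasr"
        )

    # Weight each corpus by its cut count, as the patch does.
    return CutSet.mux(*cut_sets, weights=[len(c) for c in cut_sets])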