diff --git a/egs/multi_zh_en/ASR/zipformer/multi_dataset.py b/egs/multi_zh_en/ASR/zipformer/multi_dataset.py index d8b66609c..1bc514d61 100644 --- a/egs/multi_zh_en/ASR/zipformer/multi_dataset.py +++ b/egs/multi_zh_en/ASR/zipformer/multi_dataset.py @@ -15,6 +15,7 @@ # limitations under the License. +import argparse import logging from functools import lru_cache from pathlib import Path @@ -24,14 +25,15 @@ from lhotse import CutSet, load_manifest_lazy class MultiDataset: - def __init__(self, fbank_dir: str): + def __init__(self, args: argparse.Namespace): """ Args: manifest_dir: It is expected to contain the following files: - aishell2_cuts_train.jsonl.gz """ - self.fbank_dir = Path(fbank_dir) + self.fbank_dir = Path(args.fbank_dir) + self.use_tal_csasr = args.use_tal_csasr def train_cuts(self) -> CutSet: logging.info("About to get multidataset train cuts") @@ -42,11 +44,33 @@ class MultiDataset: self.fbank_dir / "aishell2_cuts_train.jsonl.gz" ) + # TAL-CSASR + if self.use_tal_csasr: + logging.info("Loading TAL-CSASR in lazy mode") + tal_csasr_cuts = load_manifest_lazy( + self.fbank_dir / "tal_csasr_cuts_train_set.jsonl.gz" + ) + # LibriSpeech train_clean_100_cuts = self.train_clean_100_cuts() train_clean_360_cuts = self.train_clean_360_cuts() train_other_500_cuts = self.train_other_500_cuts() + if self.use_tal_csasr: + return CutSet.mux( + aishell_2_cuts, + train_clean_100_cuts, + train_clean_360_cuts, + train_other_500_cuts, + tal_csasr_cuts, + weights=[ + len(aishell_2_cuts), + len(train_clean_100_cuts), + len(train_clean_360_cuts), + len(train_other_500_cuts), + len(tal_csasr_cuts), + ], + ) return CutSet.mux( aishell_2_cuts, train_clean_100_cuts, @@ -99,7 +123,7 @@ class MultiDataset: # LibriSpeech test_clean_cuts = self.test_clean_cuts() test_other_cuts = self.test_other_cuts() - + logging.info("Loading TAL-CSASR set in lazy mode") tal_csasr_cuts = load_manifest_lazy( self.fbank_dir / "tal_csasr_cuts_test_set.jsonl.gz" diff --git a/egs/multi_zh_en/ASR/zipformer/train.py b/egs/multi_zh_en/ASR/zipformer/train.py index 8318fc1ee..0fcb7e51e 100755 --- a/egs/multi_zh_en/ASR/zipformer/train.py +++ b/egs/multi_zh_en/ASR/zipformer/train.py @@ -468,6 +468,13 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--use-tal-csasr", + type=str2bool, + default=False, + help="Whether to use TAL-CSASR training data.", + ) + add_model_arguments(parser) return parser