From 7310489dc94b9a120120492dfe1c13ed5c9b8f50 Mon Sep 17 00:00:00 2001
From: wd929
Date: Wed, 27 Dec 2023 17:00:45 +0800
Subject: [PATCH] Support training for aec-iva and gss data

---
Review notes (everything between the "---" line above and the first
"diff --git" line is ignored by `git am`):

* Fixed: train_aec_iva_cuts() and test_eval_track1_aec_iva() were missing
  the ")" that closes load_manifest_lazy(...), so the patched
  asr_datamodule.py could not be parsed. The second hunk header and the
  diffstat are updated for the two added lines.
* Fixed: test_eval_track1_aec_iva() logged "About to get aec iva dev cuts";
  the message now names the eval set, like its eval1-gss / eval1-sdm
  siblings.
* NOTE(review): dev_aec_iva_cuts() still loads
  "icmcasr-aec-iva_cuts_train_aec_iva.jsonl.gz", the same file as
  train_aec_iva_cuts(), so validation would include training data. Point it
  at the dev manifest once its file name is confirmed.
* NOTE(review): test_eval_track1_aec_iva() reads a generic "cuts.jsonl.gz"
  from --manifest-aec-iva-dir; confirm this is the intended eval manifest.
* NOTE(review): this copy was restored from a whitespace-flattened paste.
  Leading indentation and blank context lines are reconstructed, and the
  blob ids on the "index" lines describe the pre-review revision. If the
  context does not match the tree, try `git apply --ignore-whitespace`.

 egs/icmcasr/ASR/zipformer/asr_datamodule.py | 60 ++++++++++++++-------
 egs/icmcasr/ASR/zipformer/train.py          |  5 +-
 2 files changed, 46 insertions(+), 19 deletions(-)

diff --git a/egs/icmcasr/ASR/zipformer/asr_datamodule.py b/egs/icmcasr/ASR/zipformer/asr_datamodule.py
index cd6d8ea7f..c935cf7a8 100644
--- a/egs/icmcasr/ASR/zipformer/asr_datamodule.py
+++ b/egs/icmcasr/ASR/zipformer/asr_datamodule.py
@@ -101,6 +101,12 @@ class ICMCAsrDataModule:
             default=Path("data/fbank"),
             help="Path to directory with train/valid/test cuts.",
         )
+        group.add_argument(
+            "--manifest-aec-iva-dir",
+            type=Path,
+            default=Path("data/fbank"),
+            help="Path to directory with aec iva data train/valid/test cuts.",
+        )
         group.add_argument(
             "--max-duration",
             type=int,
@@ -400,51 +406,69 @@ class ICMCAsrDataModule:
         )
         return test_dl
 
-    @lru_cache()
     def train_ihm_cuts(self) -> CutSet:
         logging.info("About to get train-ihm cuts")
         return load_manifest_lazy(
             self.args.manifest_dir / "cuts_train_ihm.jsonl.gz"
         )
 
-    @lru_cache()
     def train_ihm_rvb_cuts(self) -> CutSet:
         logging.info("About to get train-ihm-rvb cuts")
         return load_manifest_lazy(
             self.args.manifest_dir / "cuts_train_ihm_rvb.jsonl.gz"
         )
 
-    @lru_cache()
     def train_shm_cuts(self) -> CutSet:
         logging.info("About to get train-shm cuts")
         return load_manifest_lazy(
             self.args.manifest_dir / "cuts_train_sdm.jsonl.gz"
         )
+    def train_gss_cuts(self) -> CutSet:
+        logging.info("About to get train-gss cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "cuts_train_gss.jsonl.gz"
+        )
+
+    def train_aec_iva_cuts(self) -> CutSet:
+        logging.info("About to get train-aec_iva cuts")
+        return load_manifest_lazy(
+            self.args.manifest_aec_iva_dir / "icmcasr-aec-iva_cuts_train_aec_iva.jsonl.gz"
+        )
 
-    @lru_cache()
     def dev_ihm_cuts(self) -> CutSet:
         logging.info("About to get dev-ihm cuts")
         return load_manifest_lazy(
             self.args.manifest_dir / "cuts_dev_ihm.jsonl.gz"
         )
 
-    @lru_cache()
     def dev_shm_cuts(self) -> CutSet:
         logging.info("About to get dev-other cuts")
         return load_manifest_lazy(
             self.args.manifest_dir / "cuts_dev_sdm.jsonl.gz"
         )
+    def dev_aec_iva_cuts(self) -> CutSet:
+        logging.info("About to get aec iva dev cuts")
+        return load_manifest_lazy(
+            self.args.manifest_aec_iva_dir / "icmcasr-aec-iva_cuts_train_aec_iva.jsonl.gz"
+        )
+    def dev_gss_cuts(self) -> CutSet:
+        logging.info("About to get dev-gss cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "cuts_dev_gss.jsonl.gz"
+        )
+    def test_eval_track1_gss(self) -> CutSet:
+        logging.info("About to get eval1-gss cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "cuts_eval_track1_gss.jsonl.gz"
+        )
+    def test_eval_track1_sdm(self) -> CutSet:
+        logging.info("About to get eval1-sdm cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "cuts_eval_track1_sdm.jsonl.gz"
+        )
+    def test_eval_track1_aec_iva(self) -> CutSet:
+        logging.info("About to get eval1-aec-iva cuts")
+        return load_manifest_lazy(
+            self.args.manifest_aec_iva_dir / "cuts.jsonl.gz"
+        )
 
-    # @lru_cache()
-    # def test_clean_cuts(self) -> CutSet:
-    #     logging.info("About to get test-clean cuts")
-    #     return load_manifest_lazy(
-    #         self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
-    #     )
-
-    # @lru_cache()
-    # def test_other_cuts(self) -> CutSet:
-    #     logging.info("About to get test-other cuts")
-    #     return load_manifest_lazy(
-    #         self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
-    #     )
diff --git a/egs/icmcasr/ASR/zipformer/train.py b/egs/icmcasr/ASR/zipformer/train.py
index 413b3c661..916ae441b 100755
--- a/egs/icmcasr/ASR/zipformer/train.py
+++ b/egs/icmcasr/ASR/zipformer/train.py
@@ -1179,6 +1179,8 @@ def run(rank, world_size, args):
     if params.full_data:
         train_cuts += icmc.train_ihm_rvb_cuts()
         train_cuts += icmc.train_shm_cuts()
+        train_cuts += icmc.train_aec_iva_cuts()
+        train_cuts += icmc.train_gss_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -1231,7 +1233,8 @@ def run(rank, world_size, args):
     )
 
     valid_cuts = icmc.dev_ihm_cuts()
-    # valid_cuts += librispeech.dev_other_cuts()
+    valid_cuts += icmc.dev_gss_cuts()
+    valid_cuts += icmc.dev_aec_iva_cuts()
     valid_dl = icmc.valid_dataloaders(valid_cuts)
 
     if not params.print_diagnostics: