From c5c35859ec44817e30fe820b182a14169d13fbbe Mon Sep 17 00:00:00 2001 From: Kinan Martin Date: Mon, 10 Mar 2025 10:12:45 +0900 Subject: [PATCH] implement musan and enable by default --- .../ASR/local/utils/asr_datamodule.py | 15 +++++++++++++-- egs/reazonspeech/ASR/zipformer/train.py | 9 ++++++++- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/egs/reazonspeech/ASR/local/utils/asr_datamodule.py b/egs/reazonspeech/ASR/local/utils/asr_datamodule.py index e70370760..3bbd48a43 100644 --- a/egs/reazonspeech/ASR/local/utils/asr_datamodule.py +++ b/egs/reazonspeech/ASR/local/utils/asr_datamodule.py @@ -174,13 +174,16 @@ class ReazonSpeechAsrDataModule: group.add_argument( "--enable-musan", type=str2bool, - default=False, + default=True, help="When enabled, select noise from MUSAN and mix it" "with training dataset. ", ) def train_dataloaders( - self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + cuts_musan: Optional[CutSet] = None, ) -> DataLoader: """ Args: @@ -191,6 +194,14 @@ class ReazonSpeechAsrDataModule: """ transforms = [] + if cuts_musan is not None: + logging.info("Enable MUSAN") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + input_transforms = [] if self.args.enable_spec_aug: diff --git a/egs/reazonspeech/ASR/zipformer/train.py b/egs/reazonspeech/ASR/zipformer/train.py index 30bd3efba..54b4a9950 100755 --- a/egs/reazonspeech/ASR/zipformer/train.py +++ b/egs/reazonspeech/ASR/zipformer/train.py @@ -1219,8 +1219,15 @@ def run(rank, world_size, args): else: sampler_state_dict = None + if args.enable_musan: + cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + else: + cuts_musan = None + train_dl = reazonspeech_corpus.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict + train_cuts, + sampler_state_dict=sampler_state_dict, + cuts_musan=cuts_musan, ) valid_cuts = reazonspeech_corpus.valid_cuts()