From 6665469a4e23b4698017c43369dc17d121840a64 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Fri, 5 Jan 2024 16:12:24 +0800 Subject: [PATCH] add args in asr datamodule --- .../asr_datamodule.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index cf70fc0f8..8b6a5a7e0 100644 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -102,6 +102,20 @@ class SPGISpeechAsrDataModule: help="Determines the maximum duration of a concatenated cut " "relative to the duration of the longest cut in a batch.", ) + group.add_argument( + "--drop-last", + type=str2bool, + default=False, + help="When enabled, the last batch will be dropped", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) group.add_argument( "--gap", type=float, @@ -143,7 +157,7 @@ class SPGISpeechAsrDataModule: group.add_argument( "--num-workers", type=int, - default=8, + default=2, help="The number of training dataloader workers that " "collect the batches.", ) @@ -176,7 +190,7 @@ class SPGISpeechAsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms = [] if self.args.enable_musan: @@ -223,11 +237,13 @@ class SPGISpeechAsrDataModule: cut_transforms=transforms, input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, + return_cuts=self.args.return_cuts, ) else: train = K2SpeechRecognitionDataset( cut_transforms=transforms, input_transforms=input_transforms, + return_cuts=self.args.return_cuts, ) logging.info("Using DynamicBucketingSampler.") @@ -236,7 +252,7 @@ class SPGISpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=False, num_buckets=self.args.num_buckets, - drop_last=True, + drop_last=self.args.drop_last, ) logging.info("About to create train dataloader") @@ -274,10 +290,12 @@ class SPGISpeechAsrDataModule: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, ) else: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, + return_cuts=self.args.return_cuts, ) valid_sampler = DynamicBucketingSampler( cuts_valid, @@ -301,6 +319,7 @@ class SPGISpeechAsrDataModule: input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) if self.args.on_the_fly_feats else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, ) sampler = DynamicBucketingSampler( cuts, max_duration=self.args.max_duration, shuffle=False