diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
index d29195ad2..7faea33f5 100644
--- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
+++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
@@ -22,9 +22,9 @@ from pathlib import Path
 
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (
-    BucketingSampler,
     CutConcatenate,
     CutMix,
+    DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
     SingleCutSampler,
@@ -82,7 +82,7 @@ class GigaSpeechAsrDataModule:
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
-            default=False,
+            default=True,
             help="When enabled, the batches will come from buckets of "
             "similar duration (saves padding frames).",
         )
@@ -90,7 +90,7 @@ class GigaSpeechAsrDataModule:
             "--num-buckets",
             type=int,
             default=30,
-            help="The number of buckets for the BucketingSampler"
+            help="The number of buckets for the DynamicBucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
@@ -142,7 +142,7 @@ class GigaSpeechAsrDataModule:
         group.add_argument(
             "--num-workers",
             type=int,
-            default=20,
+            default=2,
             help="The number of training dataloader workers that "
             "collect the batches.",
         )
@@ -270,13 +270,12 @@ class GigaSpeechAsrDataModule:
         )
 
         if self.args.bucketing_sampler:
-            logging.info("Using BucketingSampler.")
-            train_sampler = BucketingSampler(
+            logging.info("Using DynamicBucketingSampler.")
+            train_sampler = DynamicBucketingSampler(
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
-                bucket_method="equal_duration",
                 drop_last=True,
             )
         else:
@@ -321,7 +320,7 @@ class GigaSpeechAsrDataModule:
             cut_transforms=transforms,
             return_cuts=self.args.return_cuts,
         )
-        valid_sampler = BucketingSampler(
+        valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
@@ -345,7 +344,7 @@ class GigaSpeechAsrDataModule:
             else PrecomputedFeatures(),
             return_cuts=self.args.return_cuts,
         )
-        sampler = BucketingSampler(
+        sampler = DynamicBucketingSampler(
             cuts, max_duration=self.args.max_duration, shuffle=False
         )
         logging.debug("About to create test dataloader")
diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py
index 59f60939b..53b82633c 100755
--- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py
+++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py
@@ -72,9 +72,13 @@ def compute_fbank_gigaspeech_dev_test():
             batch_duration=batch_duration,
             storage_type=LilcomHdf5Writer,
         )
+        cut_set = cut_set.trim_to_supervisions(
+            keep_overlapping=False, min_duration=None
+        )
 
         logging.info(f"Saving to {cuts_path}")
         cut_set.to_file(cuts_path)
+        logging.info(f"Saved to {cuts_path}")
 
 
 def main():