diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 8290e71d1..229575db6 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -21,6 +21,10 @@ from functools import lru_cache from pathlib import Path from typing import List, Union +from torch.utils.data import DataLoader + +from icefall.dataset.datamodule import DataModule +from icefall.utils import str2bool from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( BucketingSampler, @@ -32,10 +36,6 @@ from lhotse.dataset import ( SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.dataset.datamodule import DataModule -from icefall.utils import str2bool class LibriSpeechAsrDataModule(DataModule): @@ -267,7 +267,7 @@ class LibriSpeechAsrDataModule(DataModule): cut_transforms=transforms, return_cuts=self.args.return_cuts, ) - valid_sampler = SingleCutSampler( + valid_sampler = BucketingSampler( cuts_valid, max_duration=self.args.max_duration, shuffle=False, @@ -300,12 +300,15 @@ class LibriSpeechAsrDataModule(DataModule): else PrecomputedFeatures(), return_cuts=self.args.return_cuts, ) - sampler = SingleCutSampler( - cuts_test, max_duration=self.args.max_duration + sampler = BucketingSampler( + cuts_test, max_duration=self.args.max_duration, shuffle=False ) logging.debug("About to create test dataloader") test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=1 + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, ) test_loaders.append(test_dl) diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py index e6614e3ce..832fd556e 100644 --- a/egs/yesno/ASR/tdnn/asr_datamodule.py +++ b/egs/yesno/ASR/tdnn/asr_datamodule.py @@ -20,19 +20,18 @@ from functools import lru_cache from pathlib import Path from typing import List +from torch.utils.data import DataLoader + +from icefall.dataset.datamodule import DataModule +from icefall.utils import str2bool from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( BucketingSampler, CutConcatenate, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.dataset.datamodule import DataModule -from icefall.utils import str2bool class YesNoAsrDataModule(DataModule): @@ -198,7 +197,7 @@ class YesNoAsrDataModule(DataModule): ) else: logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + train_sampler = BucketingSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, @@ -226,12 +225,15 @@ class YesNoAsrDataModule(DataModule): else PrecomputedFeatures(), return_cuts=self.args.return_cuts, ) - sampler = SingleCutSampler( - cuts_test, max_duration=self.args.max_duration + sampler = BucketingSampler( + cuts_test, max_duration=self.args.max_duration, shuffle=False ) logging.debug("About to create test dataloader") test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=1 + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, ) return test_dl