From b682467e4d5b3cc085d7fdb988f61e98979ae6a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 8 Oct 2021 22:32:13 -0400 Subject: [PATCH 1/2] Use BucketingSampler for dev and test data --- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 16 ++++++++-------- egs/yesno/ASR/tdnn/asr_datamodule.py | 17 ++++++++--------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 8290e71d1..4953e8538 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -21,6 +21,10 @@ from functools import lru_cache from pathlib import Path from typing import List, Union +from torch.utils.data import DataLoader + +from icefall.dataset.datamodule import DataModule +from icefall.utils import str2bool from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( BucketingSampler, @@ -32,10 +36,6 @@ from lhotse.dataset import ( SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.dataset.datamodule import DataModule -from icefall.utils import str2bool class LibriSpeechAsrDataModule(DataModule): @@ -267,7 +267,7 @@ class LibriSpeechAsrDataModule(DataModule): cut_transforms=transforms, return_cuts=self.args.return_cuts, ) - valid_sampler = SingleCutSampler( + valid_sampler = BucketingSampler( cuts_valid, max_duration=self.args.max_duration, shuffle=False, @@ -300,12 +300,12 @@ class LibriSpeechAsrDataModule(DataModule): else PrecomputedFeatures(), return_cuts=self.args.return_cuts, ) - sampler = SingleCutSampler( - cuts_test, max_duration=self.args.max_duration + sampler = BucketingSampler( + cuts_test, max_duration=self.args.max_duration, shuffle=False ) logging.debug("About to create test dataloader") test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=1 + test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers ) test_loaders.append(test_dl) diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py index e6614e3ce..a9a6145f0 100644 --- a/egs/yesno/ASR/tdnn/asr_datamodule.py +++ b/egs/yesno/ASR/tdnn/asr_datamodule.py @@ -20,19 +20,18 @@ from functools import lru_cache from pathlib import Path from typing import List +from torch.utils.data import DataLoader + +from icefall.dataset.datamodule import DataModule +from icefall.utils import str2bool from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( BucketingSampler, CutConcatenate, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.dataset.datamodule import DataModule -from icefall.utils import str2bool class YesNoAsrDataModule(DataModule): @@ -198,7 +197,7 @@ class YesNoAsrDataModule(DataModule): ) else: logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + train_sampler = BucketingSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, @@ -226,12 +225,12 @@ class YesNoAsrDataModule(DataModule): else PrecomputedFeatures(), return_cuts=self.args.return_cuts, ) - sampler = SingleCutSampler( - cuts_test, max_duration=self.args.max_duration + sampler = BucketingSampler( + cuts_test, max_duration=self.args.max_duration, shuffle=False ) logging.debug("About to create test dataloader") test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=1 + test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers ) return test_dl From 069ebaf9bab80209359b5ed19a3756cde2a551f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sat, 9 Oct 2021 14:45:46 +0000 Subject: [PATCH 2/2] Reformatting --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 5 ++++- egs/yesno/ASR/tdnn/asr_datamodule.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 4953e8538..229575db6 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -305,7 +305,10 @@ class LibriSpeechAsrDataModule(DataModule): ) logging.debug("About to create test dataloader") test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, ) test_loaders.append(test_dl) diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py index a9a6145f0..832fd556e 100644 --- a/egs/yesno/ASR/tdnn/asr_datamodule.py +++ b/egs/yesno/ASR/tdnn/asr_datamodule.py @@ -230,7 +230,10 @@ class YesNoAsrDataModule(DataModule): ) logging.debug("About to create test dataloader") test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, ) return test_dl