Use BucketingSampler for dev and test data

This commit is contained in:
Piotr Żelasko 2021-10-08 22:32:13 -04:00
parent adb068eb82
commit b682467e4d
2 changed files with 16 additions and 17 deletions

View File

@@ -21,6 +21,10 @@ from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import List, Union from typing import List, Union
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import ( from lhotse.dataset import (
BucketingSampler, BucketingSampler,
@@ -32,10 +36,6 @@ from lhotse.dataset import (
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import OnTheFlyFeatures from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
class LibriSpeechAsrDataModule(DataModule): class LibriSpeechAsrDataModule(DataModule):
@@ -267,7 +267,7 @@ class LibriSpeechAsrDataModule(DataModule):
cut_transforms=transforms, cut_transforms=transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
valid_sampler = SingleCutSampler( valid_sampler = BucketingSampler(
cuts_valid, cuts_valid,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=False, shuffle=False,
@@ -300,12 +300,12 @@ class LibriSpeechAsrDataModule(DataModule):
else PrecomputedFeatures(), else PrecomputedFeatures(),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
sampler = SingleCutSampler( sampler = BucketingSampler(
cuts_test, max_duration=self.args.max_duration cuts_test, max_duration=self.args.max_duration, shuffle=False
) )
logging.debug("About to create test dataloader") logging.debug("About to create test dataloader")
test_dl = DataLoader( test_dl = DataLoader(
test, batch_size=None, sampler=sampler, num_workers=1 test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers
) )
test_loaders.append(test_dl) test_loaders.append(test_dl)

View File

@@ -20,19 +20,18 @@ from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import List from typing import List
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import ( from lhotse.dataset import (
BucketingSampler, BucketingSampler,
CutConcatenate, CutConcatenate,
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler,
) )
from lhotse.dataset.input_strategies import OnTheFlyFeatures from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
class YesNoAsrDataModule(DataModule): class YesNoAsrDataModule(DataModule):
@@ -198,7 +197,7 @@ class YesNoAsrDataModule(DataModule):
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler( train_sampler = BucketingSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,
@@ -226,12 +225,12 @@ class YesNoAsrDataModule(DataModule):
else PrecomputedFeatures(), else PrecomputedFeatures(),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
sampler = SingleCutSampler( sampler = BucketingSampler(
cuts_test, max_duration=self.args.max_duration cuts_test, max_duration=self.args.max_duration, shuffle=False
) )
logging.debug("About to create test dataloader") logging.debug("About to create test dataloader")
test_dl = DataLoader( test_dl = DataLoader(
test, batch_size=None, sampler=sampler, num_workers=1 test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers
) )
return test_dl return test_dl