Merge pull request #73 from pzelasko/feature/bucketing-in-test

Use BucketingSampler for dev and test data
This commit is contained in:
Piotr Żelasko 2021-10-09 10:58:29 -04:00 committed by GitHub
commit d54828e73a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 17 deletions

View File

@@ -21,6 +21,10 @@ from functools import lru_cache
 from pathlib import Path
 from typing import List, Union
+from torch.utils.data import DataLoader
+
+from icefall.dataset.datamodule import DataModule
+from icefall.utils import str2bool
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (
     BucketingSampler,
@@ -32,10 +36,6 @@ from lhotse.dataset import (
     SpecAugment,
 )
 from lhotse.dataset.input_strategies import OnTheFlyFeatures
-
-from torch.utils.data import DataLoader
-from icefall.dataset.datamodule import DataModule
-from icefall.utils import str2bool
 class LibriSpeechAsrDataModule(DataModule):
@@ -267,7 +267,7 @@ class LibriSpeechAsrDataModule(DataModule):
             cut_transforms=transforms,
             return_cuts=self.args.return_cuts,
         )
-        valid_sampler = SingleCutSampler(
+        valid_sampler = BucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
@@ -300,12 +300,15 @@ class LibriSpeechAsrDataModule(DataModule):
                 else PrecomputedFeatures(),
                 return_cuts=self.args.return_cuts,
             )
-            sampler = SingleCutSampler(
-                cuts_test, max_duration=self.args.max_duration
+            sampler = BucketingSampler(
+                cuts_test, max_duration=self.args.max_duration, shuffle=False
             )
             logging.debug("About to create test dataloader")
             test_dl = DataLoader(
-                test, batch_size=None, sampler=sampler, num_workers=1
+                test,
+                batch_size=None,
+                sampler=sampler,
+                num_workers=self.args.num_workers,
             )
             test_loaders.append(test_dl)

View File

@@ -20,19 +20,18 @@ from functools import lru_cache
 from pathlib import Path
 from typing import List
+from torch.utils.data import DataLoader
+
+from icefall.dataset.datamodule import DataModule
+from icefall.utils import str2bool
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (
     BucketingSampler,
     CutConcatenate,
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
-    SingleCutSampler,
 )
 from lhotse.dataset.input_strategies import OnTheFlyFeatures
-
-from torch.utils.data import DataLoader
-from icefall.dataset.datamodule import DataModule
-from icefall.utils import str2bool
 class YesNoAsrDataModule(DataModule):
@@ -198,7 +197,7 @@ class YesNoAsrDataModule(DataModule):
             )
         else:
             logging.info("Using SingleCutSampler.")
-            train_sampler = SingleCutSampler(
+            train_sampler = BucketingSampler(
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
@@ -226,12 +225,15 @@ class YesNoAsrDataModule(DataModule):
             else PrecomputedFeatures(),
             return_cuts=self.args.return_cuts,
         )
-        sampler = SingleCutSampler(
-            cuts_test, max_duration=self.args.max_duration
+        sampler = BucketingSampler(
+            cuts_test, max_duration=self.args.max_duration, shuffle=False
         )
         logging.debug("About to create test dataloader")
         test_dl = DataLoader(
-            test, batch_size=None, sampler=sampler, num_workers=1
+            test,
+            batch_size=None,
+            sampler=sampler,
+            num_workers=self.args.num_workers,
         )
         return test_dl