Merge pull request #73 from pzelasko/feature/bucketing-in-test

Use BucketingSampler for dev and test data
This commit is contained in:
Piotr Żelasko 2021-10-09 10:58:29 -04:00 committed by GitHub
commit d54828e73a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 17 deletions

View File

@ -21,6 +21,10 @@ from functools import lru_cache
from pathlib import Path
from typing import List, Union
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
BucketingSampler,
@ -32,10 +36,6 @@ from lhotse.dataset import (
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
class LibriSpeechAsrDataModule(DataModule):
@ -267,7 +267,7 @@ class LibriSpeechAsrDataModule(DataModule):
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = SingleCutSampler(
valid_sampler = BucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
@ -300,12 +300,15 @@ class LibriSpeechAsrDataModule(DataModule):
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = SingleCutSampler(
cuts_test, max_duration=self.args.max_duration
sampler = BucketingSampler(
cuts_test, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test, batch_size=None, sampler=sampler, num_workers=1
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
test_loaders.append(test_dl)

View File

@ -20,19 +20,18 @@ from functools import lru_cache
from pathlib import Path
from typing import List
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
BucketingSampler,
CutConcatenate,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
class YesNoAsrDataModule(DataModule):
@ -198,7 +197,7 @@ class YesNoAsrDataModule(DataModule):
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
train_sampler = BucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
@ -226,12 +225,15 @@ class YesNoAsrDataModule(DataModule):
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = SingleCutSampler(
cuts_test, max_duration=self.args.max_duration
sampler = BucketingSampler(
cuts_test, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test, batch_size=None, sampler=sampler, num_workers=1
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl