mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Use BucketingSampler for dev and test data
This commit is contained in:
parent
adb068eb82
commit
b682467e4d
@ -21,6 +21,10 @@ from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.dataset.datamodule import DataModule
|
||||
from icefall.utils import str2bool
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
|
||||
from lhotse.dataset import (
|
||||
BucketingSampler,
|
||||
@ -32,10 +36,6 @@ from lhotse.dataset import (
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.dataset.datamodule import DataModule
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class LibriSpeechAsrDataModule(DataModule):
|
||||
@ -267,7 +267,7 @@ class LibriSpeechAsrDataModule(DataModule):
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = SingleCutSampler(
|
||||
valid_sampler = BucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
@ -300,12 +300,12 @@ class LibriSpeechAsrDataModule(DataModule):
|
||||
else PrecomputedFeatures(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = SingleCutSampler(
|
||||
cuts_test, max_duration=self.args.max_duration
|
||||
sampler = BucketingSampler(
|
||||
cuts_test, max_duration=self.args.max_duration, shuffle=False
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test, batch_size=None, sampler=sampler, num_workers=1
|
||||
test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers
|
||||
)
|
||||
test_loaders.append(test_dl)
|
||||
|
||||
|
@ -20,19 +20,18 @@ from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.dataset.datamodule import DataModule
|
||||
from icefall.utils import str2bool
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
|
||||
from lhotse.dataset import (
|
||||
BucketingSampler,
|
||||
CutConcatenate,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.dataset.datamodule import DataModule
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class YesNoAsrDataModule(DataModule):
|
||||
@ -198,7 +197,7 @@ class YesNoAsrDataModule(DataModule):
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
train_sampler = BucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
@ -226,12 +225,12 @@ class YesNoAsrDataModule(DataModule):
|
||||
else PrecomputedFeatures(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = SingleCutSampler(
|
||||
cuts_test, max_duration=self.args.max_duration
|
||||
sampler = BucketingSampler(
|
||||
cuts_test, max_duration=self.args.max_duration, shuffle=False
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test, batch_size=None, sampler=sampler, num_workers=1
|
||||
test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers
|
||||
)
|
||||
return test_dl
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user