mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Merge pull request #73 from pzelasko/feature/bucketing-in-test
Use BucketingSampler for dev and test data
This commit is contained in:
commit
d54828e73a
@ -21,6 +21,10 @@ from functools import lru_cache
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from icefall.dataset.datamodule import DataModule
|
||||||
|
from icefall.utils import str2bool
|
||||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
|
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
|
||||||
from lhotse.dataset import (
|
from lhotse.dataset import (
|
||||||
BucketingSampler,
|
BucketingSampler,
|
||||||
@ -32,10 +36,6 @@ from lhotse.dataset import (
|
|||||||
SpecAugment,
|
SpecAugment,
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from icefall.dataset.datamodule import DataModule
|
|
||||||
from icefall.utils import str2bool
|
|
||||||
|
|
||||||
|
|
||||||
class LibriSpeechAsrDataModule(DataModule):
|
class LibriSpeechAsrDataModule(DataModule):
|
||||||
@ -267,7 +267,7 @@ class LibriSpeechAsrDataModule(DataModule):
|
|||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
valid_sampler = SingleCutSampler(
|
valid_sampler = BucketingSampler(
|
||||||
cuts_valid,
|
cuts_valid,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
@ -300,12 +300,15 @@ class LibriSpeechAsrDataModule(DataModule):
|
|||||||
else PrecomputedFeatures(),
|
else PrecomputedFeatures(),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
sampler = SingleCutSampler(
|
sampler = BucketingSampler(
|
||||||
cuts_test, max_duration=self.args.max_duration
|
cuts_test, max_duration=self.args.max_duration, shuffle=False
|
||||||
)
|
)
|
||||||
logging.debug("About to create test dataloader")
|
logging.debug("About to create test dataloader")
|
||||||
test_dl = DataLoader(
|
test_dl = DataLoader(
|
||||||
test, batch_size=None, sampler=sampler, num_workers=1
|
test,
|
||||||
|
batch_size=None,
|
||||||
|
sampler=sampler,
|
||||||
|
num_workers=self.args.num_workers,
|
||||||
)
|
)
|
||||||
test_loaders.append(test_dl)
|
test_loaders.append(test_dl)
|
||||||
|
|
||||||
|
@ -20,19 +20,18 @@ from functools import lru_cache
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from icefall.dataset.datamodule import DataModule
|
||||||
|
from icefall.utils import str2bool
|
||||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
|
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
|
||||||
from lhotse.dataset import (
|
from lhotse.dataset import (
|
||||||
BucketingSampler,
|
BucketingSampler,
|
||||||
CutConcatenate,
|
CutConcatenate,
|
||||||
K2SpeechRecognitionDataset,
|
K2SpeechRecognitionDataset,
|
||||||
PrecomputedFeatures,
|
PrecomputedFeatures,
|
||||||
SingleCutSampler,
|
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from icefall.dataset.datamodule import DataModule
|
|
||||||
from icefall.utils import str2bool
|
|
||||||
|
|
||||||
|
|
||||||
class YesNoAsrDataModule(DataModule):
|
class YesNoAsrDataModule(DataModule):
|
||||||
@ -198,7 +197,7 @@ class YesNoAsrDataModule(DataModule):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SingleCutSampler.")
|
logging.info("Using SingleCutSampler.")
|
||||||
train_sampler = SingleCutSampler(
|
train_sampler = BucketingSampler(
|
||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
@ -226,12 +225,15 @@ class YesNoAsrDataModule(DataModule):
|
|||||||
else PrecomputedFeatures(),
|
else PrecomputedFeatures(),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
sampler = SingleCutSampler(
|
sampler = BucketingSampler(
|
||||||
cuts_test, max_duration=self.args.max_duration
|
cuts_test, max_duration=self.args.max_duration, shuffle=False
|
||||||
)
|
)
|
||||||
logging.debug("About to create test dataloader")
|
logging.debug("About to create test dataloader")
|
||||||
test_dl = DataLoader(
|
test_dl = DataLoader(
|
||||||
test, batch_size=None, sampler=sampler, num_workers=1
|
test,
|
||||||
|
batch_size=None,
|
||||||
|
sampler=sampler,
|
||||||
|
num_workers=self.args.num_workers,
|
||||||
)
|
)
|
||||||
return test_dl
|
return test_dl
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user