Use BucketingSampler for valid and test dataloader

This commit is contained in:
Guanbo Wang 2022-04-06 19:05:45 -04:00
parent 3ddcc7939b
commit 79211633ed

View File

@ -22,6 +22,7 @@ from pathlib import Path
from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import ( from lhotse.dataset import (
BucketingSampler,
CutConcatenate, CutConcatenate,
CutMix, CutMix,
DynamicBucketingSampler, DynamicBucketingSampler,
@ -320,7 +321,7 @@ class GigaSpeechAsrDataModule:
cut_transforms=transforms, cut_transforms=transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
valid_sampler = DynamicBucketingSampler( valid_sampler = BucketingSampler(
cuts_valid, cuts_valid,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=False, shuffle=False,
@ -344,7 +345,7 @@ class GigaSpeechAsrDataModule:
else PrecomputedFeatures(), else PrecomputedFeatures(),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
sampler = DynamicBucketingSampler( sampler = BucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False cuts, max_duration=self.args.max_duration, shuffle=False
) )
logging.debug("About to create test dataloader") logging.debug("About to create test dataloader")