DynamicBucketingSampler

Author: wgb14
Date: 2021-12-29 15:22:46 -05:00
parent bea78f6094
commit 6e5b189fc5
2 changed files with 12 additions and 9 deletions

(first changed file: the GigaSpeechAsrDataModule)

@@ -22,9 +22,9 @@ from pathlib import Path
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (
-    BucketingSampler,
     CutConcatenate,
     CutMix,
+    DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
     SingleCutSampler,
@@ -82,7 +82,7 @@ class GigaSpeechAsrDataModule:
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
-            default=False,
+            default=True,
             help="When enabled, the batches will come from buckets of "
             "similar duration (saves padding frames).",
         )
@@ -90,7 +90,7 @@ class GigaSpeechAsrDataModule:
             "--num-buckets",
             type=int,
             default=30,
-            help="The number of buckets for the BucketingSampler"
+            help="The number of buckets for the DynamicBucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
@@ -142,7 +142,7 @@ class GigaSpeechAsrDataModule:
         group.add_argument(
             "--num-workers",
             type=int,
-            default=20,
+            default=2,
             help="The number of training dataloader workers that "
             "collect the batches.",
         )
@@ -270,13 +270,12 @@ class GigaSpeechAsrDataModule:
         )
         if self.args.bucketing_sampler:
-            logging.info("Using BucketingSampler.")
-            train_sampler = BucketingSampler(
+            logging.info("Using DynamicBucketingSampler.")
+            train_sampler = DynamicBucketingSampler(
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
-                bucket_method="equal_duration",
                 drop_last=True,
             )
         else:
@@ -321,7 +320,7 @@ class GigaSpeechAsrDataModule:
             cut_transforms=transforms,
             return_cuts=self.args.return_cuts,
         )
-        valid_sampler = BucketingSampler(
+        valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
@@ -345,7 +344,7 @@ class GigaSpeechAsrDataModule:
             else PrecomputedFeatures(),
             return_cuts=self.args.return_cuts,
         )
-        sampler = BucketingSampler(
+        sampler = DynamicBucketingSampler(
             cuts, max_duration=self.args.max_duration, shuffle=False
         )
         logging.debug("About to create test dataloader")
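
Taken together, the changes above swap BucketingSampler for lhotse's DynamicBucketingSampler, which assigns cuts to duration buckets on the fly instead of precomputing fixed buckets (hence the dropped bucket_method argument). A minimal sketch of how a sampler constructed this way is typically driven by a PyTorch dataloader; the manifest path and max_duration value are placeholder assumptions, not taken from this commit:

from torch.utils.data import DataLoader

from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler, K2SpeechRecognitionDataset

# Placeholder manifest path; the recipe loads its own training cuts.
cuts_train = CutSet.from_file("gigaspeech_cuts_train.jsonl.gz")

train_sampler = DynamicBucketingSampler(
    cuts_train,
    max_duration=200.0,  # seconds of audio per batch (placeholder value)
    shuffle=True,
    num_buckets=30,  # matches the --num-buckets default above
    drop_last=True,
)

# Lhotse samplers yield whole batches of cuts, so the DataLoader is used
# with batch_size=None and the sampler passed via the `sampler` argument.
train_dl = DataLoader(
    K2SpeechRecognitionDataset(),
    sampler=train_sampler,
    batch_size=None,
    num_workers=2,  # matches the new --num-workers default
)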

(second changed file: dev/test fbank computation)

@@ -72,9 +72,13 @@ def compute_fbank_gigaspeech_dev_test():
             batch_duration=batch_duration,
             storage_type=LilcomHdf5Writer,
         )
+        cut_set = cut_set.trim_to_supervisions(
+            keep_overlapping=False, min_duration=None
+        )
         logging.info(f"Saving to {cuts_path}")
         cut_set.to_file(cuts_path)
+        logging.info(f"Saved to {cuts_path}")

 def main():
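
The addition in this second file trims each dev/test cut down to its supervision segments before saving, so downstream evaluation sees utterance-level cuts. A standalone sketch of the same call, with placeholder file names not taken from this commit:

from lhotse import CutSet

# Placeholder paths; the recipe derives these from its data directory.
cut_set = CutSet.from_file("gigaspeech_cuts_dev.jsonl.gz")

# Produces one output cut per supervision segment. With
# keep_overlapping=False, other supervisions overlapping the
# segment are not carried over into the resulting cut.
cut_set = cut_set.trim_to_supervisions(
    keep_overlapping=False,
    min_duration=None,  # None: do not extend short cuts to a minimum length
)

cut_set.to_file("gigaspeech_cuts_dev_trimmed.jsonl.gz")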