DynamicBucketingSampler

This commit is contained in:
wgb14 2021-12-29 15:22:46 -05:00
parent bea78f6094
commit 6e5b189fc5
2 changed files with 12 additions and 9 deletions

View File

@@ -22,9 +22,9 @@ from pathlib import Path
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
-BucketingSampler,
CutConcatenate,
CutMix,
+DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
@@ -82,7 +82,7 @@ class GigaSpeechAsrDataModule:
group.add_argument(
"--bucketing-sampler",
type=str2bool,
-default=False,
+default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
@@ -90,7 +90,7 @@ class GigaSpeechAsrDataModule:
"--num-buckets",
type=int,
default=30,
-help="The number of buckets for the BucketingSampler"
+help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
@@ -142,7 +142,7 @@ class GigaSpeechAsrDataModule:
group.add_argument(
"--num-workers",
type=int,
-default=20,
+default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
@@ -270,13 +270,12 @@ class GigaSpeechAsrDataModule:
)
if self.args.bucketing_sampler:
-logging.info("Using BucketingSampler.")
-train_sampler = BucketingSampler(
+logging.info("Using DynamicBucketingSampler.")
+train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
-bucket_method="equal_duration",
drop_last=True,
)
else:
@@ -321,7 +320,7 @@ class GigaSpeechAsrDataModule:
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
-valid_sampler = BucketingSampler(
+valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
@@ -345,7 +344,7 @@ class GigaSpeechAsrDataModule:
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
-sampler = BucketingSampler(
+sampler = DynamicBucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")

View File

@@ -72,9 +72,13 @@ def compute_fbank_gigaspeech_dev_test():
batch_duration=batch_duration,
storage_type=LilcomHdf5Writer,
)
+cut_set = cut_set.trim_to_supervisions(
+keep_overlapping=False, min_duration=None
+)
logging.info(f"Saving to {cuts_path}")
cut_set.to_file(cuts_path)
logging.info(f"Saved to {cuts_path}")
def main():