mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
DynamicBucketingSampler
This commit is contained in:
parent
bea78f6094
commit
6e5b189fc5
@ -22,9 +22,9 @@ from pathlib import Path
|
||||
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
|
||||
from lhotse.dataset import (
|
||||
BucketingSampler,
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
@ -82,7 +82,7 @@ class GigaSpeechAsrDataModule:
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
@ -90,7 +90,7 @@ class GigaSpeechAsrDataModule:
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the BucketingSampler"
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
@ -142,7 +142,7 @@ class GigaSpeechAsrDataModule:
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=20,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
@ -270,13 +270,12 @@ class GigaSpeechAsrDataModule:
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using BucketingSampler.")
|
||||
train_sampler = BucketingSampler(
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
bucket_method="equal_duration",
|
||||
drop_last=True,
|
||||
)
|
||||
else:
|
||||
@ -321,7 +320,7 @@ class GigaSpeechAsrDataModule:
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = BucketingSampler(
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
@ -345,7 +344,7 @@ class GigaSpeechAsrDataModule:
|
||||
else PrecomputedFeatures(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = BucketingSampler(
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts, max_duration=self.args.max_duration, shuffle=False
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
|
@ -72,9 +72,13 @@ def compute_fbank_gigaspeech_dev_test():
|
||||
batch_duration=batch_duration,
|
||||
storage_type=LilcomHdf5Writer,
|
||||
)
|
||||
cut_set = cut_set.trim_to_supervisions(
|
||||
keep_overlapping=False, min_duration=None
|
||||
)
|
||||
|
||||
logging.info(f"Saving to {cuts_path}")
|
||||
cut_set.to_file(cuts_path)
|
||||
logging.info(f"Saved to {cuts_path}")
|
||||
|
||||
|
||||
def main():
|
||||
|
Loading…
x
Reference in New Issue
Block a user