lazy loading and use SingleCutSampler

wgb14 2021-12-17 00:38:52 -05:00
parent 532309bf72
commit bea78f6094
2 changed files with 15 additions and 6 deletions

@@ -75,14 +75,14 @@ class GigaSpeechAsrDataModule:
         group.add_argument(
             "--max-duration",
             type=int,
-            default=600.0,
+            default=200.0,
             help="Maximum pooled recordings duration (seconds) in a "
             "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
-            default=True,
+            default=False,
             help="When enabled, the batches will come from buckets of "
             "similar duration (saves padding frames).",
         )
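With --bucketing-sampler now defaulting to False, training falls back to lhotse's SingleCutSampler, per the commit title. A minimal sketch of how the flag typically selects the sampler in this datamodule's train_dataloaders (the construction below follows the usual icefall pattern and is not part of this diff):

    from lhotse.dataset import BucketingSampler, SingleCutSampler

    if self.args.bucketing_sampler:
        # Groups cuts of similar duration to save padding, but must know
        # every cut's duration up front, which defeats a lazy manifest.
        train_sampler = BucketingSampler(
            cuts_train,
            max_duration=self.args.max_duration,
            shuffle=self.args.shuffle,
            num_buckets=self.args.num_buckets,
        )
    else:
        # SingleCutSampler draws cuts sequentially (optionally shuffled)
        # and can consume a lazily opened CutSet without materializing it.
        train_sampler = SingleCutSampler(
            cuts_train,
            max_duration=self.args.max_duration,
            shuffle=self.args.shuffle,
        )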
@@ -179,6 +179,12 @@ class GigaSpeechAsrDataModule:
             default="XL",
             help="Select the GigaSpeech subset (XS|S|M|L|XL)",
         )
+        group.add_argument(
+            "--lazy-load",
+            type=str2bool,
+            default=True,
+            help="lazily open CutSets to avoid OOM (for L|XL subset)",
+        )
         group.add_argument(
             "--small-dev",
             type=str2bool,
@@ -354,9 +360,12 @@ class GigaSpeechAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info(f"About to get train_{self.args.subset} cuts")
-        return load_manifest(
-            self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz"
-        )
+        path = self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz"
+        if self.args.subset in ["L", "XL"] and self.args.lazy_load:
+            cuts_train = CutSet.from_jsonl_lazy(path)
+        else:
+            cuts_train = CutSet.from_file(path)
+        return cuts_train
 
     @lru_cache()
     def dev_cuts(self) -> CutSet:
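For context, a minimal sketch of the eager-vs-lazy distinction, assuming an lhotse version that provides CutSet.from_jsonl_lazy (the manifest path is illustrative):

    from lhotse import CutSet

    path = "data/fbank/cuts_XL.jsonl.gz"  # illustrative path

    # Eager: parses the entire gzipped JSONL manifest into memory at once;
    # for the L/XL subsets this can exhaust RAM.
    cuts_eager = CutSet.from_file(path)

    # Lazy: keeps a file handle and deserializes cuts one at a time during
    # iteration, so memory use stays roughly constant regardless of size.
    cuts_lazy = CutSet.from_jsonl_lazy(path)

    for cut in cuts_lazy:  # cuts are yielded on demand
        print(cut.id)
        break

A lazy CutSet supports sequential iteration but not cheap random access or len(), which is presumably why the eager path is kept for the smaller subsets.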

@@ -625,7 +625,7 @@ def run(rank, world_size, args):
     train_cuts = GigaSpeech.train_cuts()
     train_dl = GigaSpeech.train_dataloaders(train_cuts)
 
-    valid_cuts = GigaSpeech.dev_clean_cuts()
+    valid_cuts = GigaSpeech.dev_cuts()
     valid_dl = GigaSpeech.valid_dataloaders(valid_cuts)
 
     scan_pessimistic_batches_for_oom(
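The second file's change repairs what looks like a leftover from the LibriSpeech recipe: dev_clean_cuts() is a LibriSpeech datamodule accessor, while GigaSpeechAsrDataModule defines dev_cuts(), as the hunk at line 360 above shows. In context, the corrected section of run() reads roughly as follows; only the lines shown in the diff are verbatim, the surrounding structure is assumed:

    # Assumed surrounding structure of run().
    GigaSpeech = GigaSpeechAsrDataModule(args)

    train_cuts = GigaSpeech.train_cuts()
    train_dl = GigaSpeech.train_dataloaders(train_cuts)

    valid_cuts = GigaSpeech.dev_cuts()  # was dev_clean_cuts(), not defined here
    valid_dl = GigaSpeech.valid_dataloaders(valid_cuts)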