From bea78f609445a407db2377304818da550268d79c Mon Sep 17 00:00:00 2001
From: wgb14
Date: Fri, 17 Dec 2021 00:38:52 -0500
Subject: [PATCH] Use lazy loading and SingleCutSampler

---
 .../ASR/conformer_ctc/asr_datamodule.py   | 19 ++++++++++++++-----
 egs/gigaspeech/ASR/conformer_ctc/train.py |  2 +-
 2 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
index b2d726adf..d29195ad2 100644
--- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
+++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
@@ -75,14 +75,14 @@ class GigaSpeechAsrDataModule:
         group.add_argument(
             "--max-duration",
             type=int,
-            default=600.0,
+            default=200.0,
             help="Maximum pooled recordings duration (seconds) in a "
             "single batch. You can reduce it if it causes CUDA OOM.",
         )
         group.add_argument(
             "--bucketing-sampler",
             type=str2bool,
-            default=True,
+            default=False,
             help="When enabled, the batches will come from buckets of "
             "similar duration (saves padding frames).",
         )
@@ -179,6 +179,12 @@ class GigaSpeechAsrDataModule:
             default="XL",
             help="Select the GigaSpeech subset (XS|S|M|L|XL)",
         )
+        group.add_argument(
+            "--lazy-load",
+            type=str2bool,
+            default=True,
+            help="lazily open CutSets to avoid OOM (for L|XL subset)",
+        )
         group.add_argument(
             "--small-dev",
             type=str2bool,
@@ -354,9 +360,12 @@ class GigaSpeechAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info(f"About to get train_{self.args.subset} cuts")
-        return load_manifest(
-            self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz"
-        )
+        path = self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz"
+        if self.args.subset in ["L", "XL"] and self.args.lazy_load:
+            cuts_train = CutSet.from_jsonl_lazy(path)
+        else:
+            cuts_train = CutSet.from_file(path)
+        return cuts_train
 
     @lru_cache()
     def dev_cuts(self) -> CutSet:
diff --git a/egs/gigaspeech/ASR/conformer_ctc/train.py b/egs/gigaspeech/ASR/conformer_ctc/train.py
index 7642e842c..adfcbc820 100755
--- a/egs/gigaspeech/ASR/conformer_ctc/train.py
+++ b/egs/gigaspeech/ASR/conformer_ctc/train.py
@@ -625,7 +625,7 @@ def run(rank, world_size, args):
     train_cuts = GigaSpeech.train_cuts()
     train_dl = GigaSpeech.train_dataloaders(train_cuts)
 
-    valid_cuts = GigaSpeech.dev_clean_cuts()
+    valid_cuts = GigaSpeech.dev_cuts()
     valid_dl = GigaSpeech.valid_dataloaders(valid_cuts)
 
     scan_pessimistic_batches_for_oom(
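
Reviewer note: the rewritten train_cuts() above branches between eager and
lazy manifest loading. Below is a minimal sketch of the difference, assuming
Lhotse's CutSet API as used in the patch; the manifest path and variable
names are illustrative only, not part of the change.

    # Minimal sketch (not part of the patch); path and names are illustrative.
    from lhotse import CutSet

    path = "data/fbank/cuts_XL.jsonl.gz"

    # Eager: parses the entire gzipped JSONL manifest into memory up front.
    # For the XL subset (~10,000 hours of cuts) this can exhaust RAM.
    eager_cuts = CutSet.from_file(path)

    # Lazy: keeps a handle to the file and deserializes cuts one at a time
    # as they are iterated, so memory use stays roughly constant.
    lazy_cuts = CutSet.from_jsonl_lazy(path)

    # A lazy CutSet is consumed sequentially; SingleCutSampler iterates it
    # that way, which is consistent with this patch flipping the
    # --bucketing-sampler default to False.
    for cut in lazy_cuts:
        print(cut.id)
        break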