diff --git a/egs/librispeech/ASR/conformer_ctc/gigaspeech_datamodule.py b/egs/librispeech/ASR/conformer_ctc/gigaspeech_datamodule.py
index 1813a4963..ee3b62a36 100644
--- a/egs/librispeech/ASR/conformer_ctc/gigaspeech_datamodule.py
+++ b/egs/librispeech/ASR/conformer_ctc/gigaspeech_datamodule.py
@@ -2,13 +2,14 @@
 # Apache 2.0
 import argparse
 import logging
+import warnings
 from functools import lru_cache
 from pathlib import Path
 from typing import List, Union

 from torch.utils.data import DataLoader

-from lhotse import CutSet, Fbank, FbankConfig, load_manifest
+from lhotse import CutSet, KaldifeatFbank, FbankConfig, load_manifest
 from lhotse.dataset import (
     BucketingSampler,
     CutConcatenate,
@@ -24,11 +25,17 @@ from icefall.utils import str2bool
 from icefall.dataset.datamodule import DataModule


-def get_context_suffix(args):
-    if args.giga_context_window is None or args.giga_context_window <= 0.0:
-        ctx_suffix = ""
+def get_context_suffix(args, subparser=True):
+    if subparser:
+        if args.giga_context_window is None or args.giga_context_window <= 0.0:
+            ctx_suffix = ""
+        else:
+            ctx_suffix = f"_{args.giga_context_direction}{args.giga_context_window}"
     else:
-        ctx_suffix = f"_{args.giga_context_direction}{args.giga_context_window}"
+        if args.context_window is None or args.context_window <= 0.0:
+            ctx_suffix = ""
+        else:
+            ctx_suffix = f"_{args.context_direction}{args.context_window}"
     return ctx_suffix


@@ -36,13 +43,14 @@ class GigaSpeechAsrDataModule(DataModule):
     """
     DataModule for K2 ASR experiments.
     It assumes there is always one train and valid dataloader,
-    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean and test-other).
+
     It contains all the common data pipeline modules used in ASR experiments, e.g.:
     - dynamic batch size,
     - bucketing samplers,
     - cut concatenation,
     - augmentation,
     - on-the-fly feature extraction
+    This class should be derived for specific corpora used in ASR tasks.
     """

@@ -57,83 +65,103 @@ class GigaSpeechAsrDataModule(DataModule):
         parser = subparsers.add_parser(name='giga')
         super().add_arguments(parser)
         group = parser.add_argument_group(
-            title='ASR data related options',
-            description='These options are used for the preparation of PyTorch DataLoaders '
-                        'from Lhotse CutSet\'s -- they control the effective batch sizes, '
-                        'sampling strategies, applied data augmentations, etc.'
+            title="ASR data related options",
+            description="These options are used for the preparation of PyTorch DataLoaders "
+            "from Lhotse CutSet's -- they control the effective batch sizes, "
+            "sampling strategies, applied data augmentations, etc.",
         )
         group.add_argument(
-            '--feature-dir',
+            "--feature-dir",
             dest="giga_feature_dir",
             type=Path,
             default=Path('exp/giga_data'),
-            help='Path to directory with train/valid/test cuts.'
+            help="Path to directory with train/valid/test cuts.",
         )
         group.add_argument(
-            '--max-duration',
+            "--max-duration",
             dest="giga_max_duration",
             type=int,
             default=500.0,
-            help="Maximum pooled recordings duration (seconds) in a single batch.")
+            help="Maximum pooled recordings duration (seconds) in a single batch.",
+        )
         group.add_argument(
-            '--bucketing-sampler',
+            "--bucketing-sampler",
            dest="giga_bucketing_sampler",
             type=str2bool,
             default=False,
-            help='When enabled, the batches will come from buckets of '
-                 'similar duration (saves padding frames).')
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
+        )
         group.add_argument(
-            '--num-buckets',
+            "--num-buckets",
+            dest="giga_num_buckets",
             type=int,
             default=30,
-            dest="giga_num_buckets",
-            help='The number of buckets for the BucketingSampler'
-                 '(you might want to increase it for larger datasets).')
+            help="The number of buckets for the BucketingSampler "
+            "(you might want to increase it for larger datasets).",
+        )
         group.add_argument(
-            '--concatenate-cuts',
+            "--concatenate-cuts",
             dest="giga_concatenate_cuts",
             type=str2bool,
             default=True,
-            help='When enabled, utterances (cuts) will be concatenated '
-                 'to minimize the amount of padding.')
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
+        )
         group.add_argument(
-            '--duration-factor',
+            "--duration-factor",
             dest="giga_duration_factor",
             type=float,
             default=1.0,
-            help='Determines the maximum duration of a concatenated cut '
-                 'relative to the duration of the longest cut in a batch.')
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
+        )
         group.add_argument(
-            '--gap',
+            "--gap",
             dest="giga_gap",
             type=float,
             default=1.0,
-            help='The amount of padding (in seconds) inserted between concatenated cuts. '
-                 'This padding is filled with noise when noise augmentation is used.')
+            help="The amount of padding (in seconds) inserted between concatenated cuts. "
+            "This padding is filled with noise when noise augmentation is used.",
+        )
         group.add_argument(
-            '--on-the-fly-feats',
+            "--on-the-fly-feats",
             dest="giga_on_the_fly_feats",
             type=str2bool,
             default=False,
-            help='When enabled, use on-the-fly cut mixing and feature extraction. '
-                 'Will drop existing precomputed feature manifests if available.'
+            help="When enabled, use on-the-fly cut mixing and feature extraction. "
+            "Will drop existing precomputed feature manifests if available.",
         )
         group.add_argument(
-            '--shuffle',
+            "--shuffle",
             dest="giga_shuffle",
             type=str2bool,
             default=True,
-            help='When enabled (=default), the examples will be shuffled for each epoch.'
-        )
+            help="When enabled (=default), the examples will be shuffled for each epoch.",
+        )
         group.add_argument(
-            '--check-cuts',
-            dest="giga_check_cuts",
+            "--return-cuts",
+            dest="giga_return_cuts",
             type=str2bool,
             default=True,
-            help='When enabled (=default), we will iterate over the whole training cut set '
-                 'to validate it. It should be disabled when using Apache Arrow manifests '
-                 'to avoid an excessive starting time of the script with datasets>1000h.'
-        )
+            help="When enabled, each batch will have the field: batch['supervisions']['cut']"
+            " with the cuts that were used to construct it.",
+        )
+        group.add_argument(
+            "--num-workers",
+            dest="giga_num_workers",
+            type=int,
+            default=4,
+            help="The number of training dataloader workers that collect the batches.",
+        )
+        group.add_argument(
+            "--num-workers-inner",
+            dest="giga_num_workers_inner",
+            type=int,
+            default=16,
+            help="The number of sub-workers (replicated for each of the training dataloader"
+            " workers) that parallelize the I/O to collect each batch.",
+        )

         # GigaSpeech specific arguments
         group.add_argument(
@@ -162,35 +190,36 @@ class GigaSpeechAsrDataModule(DataModule):
             "to seek for extra acoustic context. Available values: (left|right|center|random).",
         )
         group.add_argument(
-            '--use-context-for-test',
+            "--use-context-for-test",
             dest="giga_use_context_for_text",
             type=str2bool,
             default=False,
-            help='Should we read cuts with acoustic context or without it. '
-                 '(note: for now, they may contain duplicated segments)'
+            help="Should we read cuts with acoustic context or without it. "
+            "(note: for now, they may contain duplicated segments)",
         )
         group.add_argument(
-            '--small-dev',
+            "--small-dev",
             dest="giga_small_dev",
             type=str2bool,
             default=False,
-            help='Should we use only 1000 utterances for dev (speeds up training)'
+            help="Should we use only 1000 utterances for dev (speeds up training)",
         )

     def validate_args(self):
-        if self.args.giga_subset in ['L', 'XL']:
+        if self.args.giga_subset in ["L", "XL"]:
             assert (
-                self.args.giga_shuffle == False
+                self.args.giga_shuffle == False
             ), "For GigaSpeech L/XL, you must use --shuffle 0 to avoid eagerly reading pyarrow manifests."
             assert (
-                self.args.giga_check_cuts == False
-            ), "For GigaSpeech L/XL, you must use --check-cuts 0 to avoid eagerly reading pyarrow manifests."
-            assert (
-                self.args.giga_bucketing_sampler == False
+                self.args.giga_bucketing_sampler == False
             ), "For GigaSpeech L/XL, you must use --bucketing-sampler 0 to avoid eagerly reading pyarrow manifests."
-            assert (
-                self.args.giga_on_the_fly_feats == True
-            ), "For GigaSpeech L/XL, you must use --on-the-fly-feats 1 as we do not pre-compute them by default."
+            # compute_and_store_features_batch is efficient for L/XL subsets.
+            # if not self.args.giga_on_the_fly_feats:
+            #     warnings.warn(
+            #         "For GigaSpeech L/XL, we advise to set --on-the-fly-feats 1,"
+            #         " as we do not pre-compute them by default. If you pre-computed them,"
+            #         " ignore this warning."
+            #     )

     def train_dataloaders(self) -> DataLoader:
         self.validate_args()
@@ -200,27 +229,26 @@ class GigaSpeechAsrDataModule(DataModule):
         self.consumed_cuts = 0

         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.giga_feature_dir / 'cuts_musan.json.gz')
+        cuts_musan = load_manifest(self.args.giga_feature_dir / "cuts_musan.json.gz")

         logging.info("About to create train dataset")
         transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
         if self.args.giga_concatenate_cuts:
-            logging.info(f'Using cut concatenation with duration factor '
-                         f'{self.args.giga_duration_factor} and gap {self.args.giga_gap}.')
+            logging.info(
+                f"Using cut concatenation with duration factor "
+                f"{self.args.giga_duration_factor} and gap {self.args.giga_gap}."
+            )
             # Cut concatenation should be the first transform in the list,
             # so that if we e.g. mix noise in, it will fill the gaps between different utterances.
             transforms = [
                 CutConcatenate(
-                    duration_factor=self.args.giga_duration_factor,
-                    gap=self.args.giga_gap
+                    duration_factor=self.args.giga_duration_factor, gap=self.args.giga_gap
                 )
             ] + transforms

         train = K2SpeechRecognitionDataset(
-            # cuts_train,
             cut_transforms=transforms,
-            return_cuts=True,
-            # check_inputs=self.args.giga_check_cuts,
+            return_cuts=self.args.giga_return_cuts,
         )

         if self.args.giga_on_the_fly_feats:
@@ -231,75 +259,71 @@ class GigaSpeechAsrDataModule(DataModule):
             #     # but in principle the transforms order doesn't have to be strict (e.g. could be randomized)
             #     transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms
             train = K2SpeechRecognitionDataset(
-                cuts=cuts_train,
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)), num_workers=20),
-                return_cuts=True,
-                # check_inputs=self.args.giga_check_cuts,
+                input_strategy=OnTheFlyFeatures(
+                    KaldifeatFbank(FbankConfig(num_mel_bins=80)),
+                    num_workers=self.args.giga_num_workers_inner,
+                ),
+                return_cuts=self.args.giga_return_cuts,
             )

         if self.args.giga_bucketing_sampler:
-            logging.info('Using BucketingSampler.')
+            logging.info("Using BucketingSampler.")
             train_sampler = BucketingSampler(
                 cuts_train,
                 max_duration=self.args.giga_max_duration,
                 shuffle=self.args.giga_shuffle,
-                num_buckets=self.args.giga_num_buckets
+                num_buckets=self.args.giga_num_buckets,
             )
         else:
-            logging.info('Using SingleCutSampler.')
+            logging.info("Using SingleCutSampler.")
             train_sampler = SingleCutSampler(
                 cuts_train,
                 max_duration=self.args.giga_max_duration,
                 shuffle=self.args.giga_shuffle,
             )
         logging.info("About to create train dataloader")
-        #train_dl = DataLoader(
+        # train_dl = DataLoader(
         #    train,
         #    sampler=train_sampler,
         #    batch_size=None,
         #    num_workers=16,
         #    persistent_workers=True,
-        #)
+        # )
         train_dl = LhotseDataLoader(
             train,
             sampler=train_sampler,
-            num_workers=3,
+            num_workers=self.args.giga_num_workers,
             prefetch_factor=5,
         )
         return train_dl

-    def inexhaustible_train_dataloaders(self):
-        return self
-
     def valid_dataloaders(self) -> DataLoader:
         self.validate_args()

         logging.info("About to get dev cuts")
         cuts_valid = self.valid_cuts()

-        transforms = [ ]
+        transforms = []
         if self.args.giga_concatenate_cuts:
-            transforms = [ CutConcatenate(
-                duration_factor=self.args.giga_duration_factor,
-                gap=self.args.giga_gap)
-            ] + transforms
-
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.giga_duration_factor, gap=self.args.giga_gap
+                )
+            ] + transforms

         logging.info("About to create dev dataset")
         if self.args.giga_on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
-                cuts_valid,
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)), num_workers=8),
-                return_cuts=True,
-                check_inputs=self.args.giga_check_cuts,
+                input_strategy=OnTheFlyFeatures(
+                    KaldifeatFbank(FbankConfig(num_mel_bins=80)), num_workers=8
+                ),
+                return_cuts=self.args.giga_return_cuts,
             )
         else:
             validate = K2SpeechRecognitionDataset(
-                # cuts_valid,
                 cut_transforms=transforms,
-                return_cuts=True,
-                # check_inputs=self.args.giga_check_cuts,
+                return_cuts=self.args.giga_return_cuts,
             )
         valid_sampler = SingleCutSampler(
             cuts_valid,
@@ -307,13 +331,13 @@ class GigaSpeechAsrDataModule(DataModule):
             shuffle=False,
         )
         logging.info("About to create dev dataloader")
-        #valid_dl = DataLoader(
+        # valid_dl = DataLoader(
         #    validate,
         #    sampler=valid_sampler,
         #    batch_size=None,
         #    num_workers=8,
         #    persistent_workers=True,
-        #)
+        # )
         valid_dl = LhotseDataLoader(
             validate,
             sampler=valid_sampler,
@@ -332,18 +356,16 @@ class GigaSpeechAsrDataModule(DataModule):
         for cuts_test in cuts:
             logging.debug("About to create test dataset")
             test = K2SpeechRecognitionDataset(
-                cuts_test,
                 input_strategy=(
-                    OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)), num_workers=8)
+                    OnTheFlyFeatures(KaldifeatFbank(FbankConfig(num_mel_bins=80)), num_workers=8)
                     if self.args.giga_on_the_fly_feats
                     else PrecomputedFeatures()
                 ),
-                return_cuts=True,
-                check_inputs=self.args.giga_check_cuts,
+                return_cuts=self.args.giga_return_cuts,
             )
             sampler = SingleCutSampler(cuts_test, max_duration=self.args.giga_max_duration)
             logging.debug("About to create test dataloader")
-            #test_dl = DataLoader(test, batch_size=None, sampler=sampler, num_workers=1)
+            # test_dl = DataLoader(test, batch_size=None, sampler=sampler, num_workers=1)
             test_dl = LhotseDataLoader(test, sampler=sampler, num_workers=2)
             test_loaders.append(test_dl)

@@ -355,18 +377,30 @@ class GigaSpeechAsrDataModule(DataModule):
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        # Note: for L and XL subsets, we are expecting that the training manifest is stored using pyarrow and pre-shuffled.
-        cuts_path_ext = 'jsonl.gz' if self.args.giga_subset not in ['L', 'XL'] else 'arrow'
-        cuts_train = CutSet.from_file(
-            self.args.giga_feature_dir
-            / f"gigaspeech_cuts_{self.args.giga_subset}{get_context_suffix(self.args)}.{cuts_path_ext}"
+        path = (
+            self.args.giga_feature_dir
+            / f"gigaspeech_cuts_{self.args.giga_subset}{get_context_suffix(self.args)}.jsonl.gz"
         )
+        if self.args.giga_subset in ["L", "XL"]:
+            # "L" and "XL" partitions are large enough that we have to read their manifests lazily;
+            # The "CutSet" holds a file handle and reads the items sequentially on-the-fly to avoid
+            # wasting memory and time pre-reading everything. Some operations on "CutSet" won't work,
+            # e.g. shuffling (or they would have read everything into memory in the process).
+            # We expect that the manifests read lazily are pre-shuffled, otherwise you might experience
+            # issues with convergence.
+            cuts_train = CutSet.from_jsonl_lazy(path)
+        else:
+            # For other subsets, just read everything into memory.
+            cuts_train = CutSet.from_file(path)
         return cuts_train

     @lru_cache()
     def valid_cuts(self) -> CutSet:
         if self.args.giga_use_context_for_test:
-            path = self.args.giga_feature_dir / f"gigaspeech_cuts_DEV{get_context_suffix(self.args)}.jsonl.gz"
+            path = (
+                self.args.giga_feature_dir
+                / f"gigaspeech_cuts_DEV{get_context_suffix(self.args)}.jsonl.gz"
+            )
         else:
             path = self.args.giga_feature_dir / f"gigaspeech_cuts_DEV.jsonl.gz"
         logging.info(f"About to get valid cuts from {path}")
@@ -379,7 +413,10 @@ class GigaSpeechAsrDataModule(DataModule):
     @lru_cache()
     def test_cuts(self) -> CutSet:
         if self.args.giga_use_context_for_test:
-            path = self.args.giga_feature_dir / f"gigaspeech_cuts_TEST{get_context_suffix(self.args)}.jsonl.gz"
+            path = (
+                self.args.giga_feature_dir
+                / f"gigaspeech_cuts_TEST{get_context_suffix(self.args)}.jsonl.gz"
+            )
         else:
             path = self.args.giga_feature_dir / f"gigaspeech_cuts_TEST.jsonl.gz"
         logging.info(f"About to get test cuts from {path}")
diff --git a/egs/librispeech/ASR/prepare_gigaspeech.py b/egs/librispeech/ASR/prepare_gigaspeech.py
index 92daebdc8..22b5aab30 100755
--- a/egs/librispeech/ASR/prepare_gigaspeech.py
+++ b/egs/librispeech/ASR/prepare_gigaspeech.py
@@ -22,7 +22,6 @@ from lhotse import (
     combine,
 )
 from lhotse.recipes import prepare_gigaspeech, prepare_musan
-from lhotse.utils import is_module_available
 from icefall.utils import str2bool

 # Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and
@@ -81,7 +80,7 @@ def get_parser():
     parser.add_argument(
         "--num-jobs",
         type=int,
-        default=min(5, os.cpu_count()),
+        default=min(15, os.cpu_count()),
         help="Number of parallel jobs.",
     )
     parser.add_argument(
@@ -115,6 +114,19 @@ def get_parser():
         "might currently consume excessive memory and time -- use on-the-fly feature "
         "extraction in the training script instead.",
     )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=4,
+        help="Number of workers for compute_and_store_features_batch.",
+    )
+    parser.add_argument(
+        "--batch-duration",
+        type=float,
+        default=600.0,
+        help="The maximum number of audio seconds in a batch "
+        "for compute_and_store_features_batch.",
+    )
     return parser


@@ -139,12 +151,6 @@ def has_no_oov(
 def main():
     args = get_parser().parse_args()
     dataset_parts = [args.subset, "DEV", "TEST"]
-    if args.subset in ["L", "XL"]:
-        assert is_module_available("pyarrow"), (
-            "Running the GigaSpeech recipe for L and XL splits "
-            "currently requires installing optional dependencies: "
-            "'pip install pyarrow pandas'."
-        )

     print("Parts we will prepare: ", dataset_parts)

@@ -159,7 +165,7 @@ def main():
         Path("/root/fangjun/data/musan"),
     )

-    output_dir = Path("exp/data")
+    output_dir = Path("exp/giga_data")
     print("GigaSpeech manifest preparation:")
     gigaspeech_manifests = prepare_gigaspeech(
         corpus_dir=corpus_dir,
@@ -174,21 +180,19 @@ def main():
         corpus_dir=musan_dir, output_dir=output_dir, parts=("music", "speech", "noise")
     )

-    ctx_suffix = get_context_suffix(args)
+    ctx_suffix = get_context_suffix(args, subparser=False)

     print("Feature extraction:")
     extractor = Fbank(FbankConfig(num_mel_bins=80))
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, manifests in gigaspeech_manifests.items():
-            # For L and XL partition we are going to store the manifest using pyarrow.
-            cuts_path_ext = "jsonl.gz" if partition not in ["L", "XL"] else "arrow"
             raw_cuts_path = output_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
             cuts_path = (
-                output_dir / f"gigaspeech_cuts_{partition}{ctx_suffix}.{cuts_path_ext}"
+                output_dir / f"gigaspeech_cuts_{partition}{ctx_suffix}.jsonl.gz"
             )

             if raw_cuts_path.is_file():
-                print(f"{partition} already exists - skipping checking transcript.")
+                print(f"{partition} already exists - skipping feature extraction.")
             else:
                 # Note this step makes the recipe different than LibriSpeech:
                 # We must filter out some utterances and remove punctuation to be consistent with Kaldi.
@@ -217,7 +221,7 @@ def main():

             if cuts_path.is_file():
                 print(
-                    f"{partition} already exists - skipping cutting into sub-segments and feature extraction."
+                    f"{partition} already exists - skipping cutting into sub-segments."
                 )
             else:
                 try:
@@ -241,7 +245,7 @@ def main():
                         context_direction=args.context_direction,
                     )
                     if partition in ["L", "XL"]:
-                        # Before storing manifests in the arrow format, we want to pre-shuffle them,
+                        # Before storing the manifests, we want to pre-shuffle them,
                         # as the sampler won't be able to do it later in an efficient manner.
                         cut_set = cut_set.shuffle()

@@ -252,14 +256,21 @@ def main():
                     # data augmentation and feature computation for long recordings yet.
                     # Therefore, we sacrifice some storage for the ability to precompute
                     # features on shorter chunks, without memory blow-ups.
-                    cut_set = cut_set.compute_and_store_features(
+                    # cut_set = cut_set.compute_and_store_features(
+                    #     extractor=extractor,
+                    #     storage_path=f"{output_dir}/feats_gigaspeech_{partition}",
+                    #     # when an executor is specified, make more partitions
+                    #     num_jobs=args.num_jobs if ex is None else 80,
+                    #     executor=ex,
+                    # )
+                    cut_set = cut_set.compute_and_store_features_batch(
                         extractor=extractor,
                         storage_path=f"{output_dir}/feats_gigaspeech_{partition}",
-                        # when an executor is specified, make more partitions
-                        num_jobs=args.num_jobs if ex is None else 80,
-                        executor=ex,
+                        batch_duration=args.batch_duration,
+                        num_workers=args.num_workers,
                     )
+
                 cut_set.to_file(cuts_path)

                 # Remove cut_set so the next iteration can correctly infer whether it needs to
@@ -278,13 +289,19 @@ def main():
             )
             .cut_into_windows(10.0)
             .filter(lambda c: c.duration > 5)
-            .compute_and_store_features(
+            .compute_and_store_features_batch(
                 extractor=extractor,
                 storage_path=f"{output_dir}/feats_musan",
-                num_jobs=args.num_jobs if ex is None else 80,
-                executor=ex,
-                storage_type=LilcomHdf5Writer,
+                batch_duration=args.batch_duration,
+                num_workers=args.num_workers,
             )
+            # .compute_and_store_features(
+            #     extractor=extractor,
+            #     storage_path=f"{output_dir}/feats_musan",
+            #     num_jobs=args.num_jobs if ex is None else 80,
+            #     executor=ex,
+            #     storage_type=LilcomHdf5Writer,
+            # )
         )
         musan_cuts.to_file(musan_cuts_path)
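
Usage sketch (illustrative, not part of the patch): the lazy-manifest path added to train_cuts() above is meant to feed a sampler-driven dataset. The snippet below only mirrors calls that appear in the diff (CutSet.from_jsonl_lazy, OnTheFlyFeatures, SingleCutSampler, K2SpeechRecognitionDataset); the feature directory, the subset name, and the plain torch DataLoader at the end are assumptions made for illustration rather than part of the recipe.

# Hedged sketch of the lazy L/XL data path (not part of the patch).
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig
from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
from lhotse.dataset.input_strategies import OnTheFlyFeatures

feature_dir = Path("exp/giga_data")  # assumed location, mirrors --feature-dir
subset = "XL"                        # illustrative subset name

# The L/XL manifest is opened lazily: the CutSet keeps a file handle and streams
# cuts one by one, so shuffling/bucketing (which need the whole set in memory)
# stays disabled and the manifest is expected to be pre-shuffled.
cuts_train = CutSet.from_jsonl_lazy(feature_dir / f"gigaspeech_cuts_{subset}.jsonl.gz")

train = K2SpeechRecognitionDataset(
    input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
    return_cuts=True,
)
sampler = SingleCutSampler(cuts_train, max_duration=500.0, shuffle=False)

# A plain PyTorch DataLoader works because the sampler yields whole batches
# (batch_size=None); the recipe itself uses icefall's LhotseDataLoader instead.
train_dl = torch.utils.data.DataLoader(train, sampler=sampler, batch_size=None, num_workers=2)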
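
Usage sketch (illustrative, not part of the patch): the batched feature extraction that prepare_gigaspeech.py switches to. Only the compute_and_store_features_batch arguments used in the diff (extractor, storage_path, batch_duration, num_workers) are relied on; the manifest names below are placeholders.

# Hedged sketch of batched (chunked) feature extraction (not part of the patch).
from pathlib import Path

from lhotse import CutSet, Fbank, FbankConfig

output_dir = Path("exp/giga_data")  # mirrors the output_dir used above
cuts = CutSet.from_file(output_dir / "gigaspeech_cuts_DEV_raw.jsonl.gz")  # placeholder manifest

extractor = Fbank(FbankConfig(num_mel_bins=80))

# Unlike compute_and_store_features(), the *_batch variant loads audio with a pool
# of dataloading workers and extracts features over batches bounded by
# batch_duration seconds, which keeps memory bounded on the very large subsets.
cuts = cuts.compute_and_store_features_batch(
    extractor=extractor,
    storage_path=f"{output_dir}/feats_gigaspeech_DEV",
    batch_duration=600.0,  # matches the --batch-duration default added above
    num_workers=4,         # matches the --num-workers default added above
)
cuts.to_file(output_dir / "gigaspeech_cuts_DEV.jsonl.gz")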