diff --git a/egs/librispeech/ASR/conformer_ctc/gigaspeech_datamodule.py b/egs/librispeech/ASR/conformer_ctc/gigaspeech_datamodule.py
new file mode 100644
index 000000000..0698154ea
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc/gigaspeech_datamodule.py
@@ -0,0 +1,449 @@
+# Copyright (c) 2021 Johns Hopkins University (Piotr Żelasko)
+# Apache 2.0
+import argparse
+import logging
+import warnings
+from functools import lru_cache
+from pathlib import Path
+from typing import List, Union
+
+from torch.utils.data import DataLoader
+
+from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, load_manifest
+from lhotse.dataset import (
+    BucketingSampler,
+    CutConcatenate,
+    CutMix,
+    K2SpeechRecognitionDataset,
+    PrecomputedFeatures,
+    SingleCutSampler,
+    SpecAugment,
+)
+from lhotse.dataset.dataloading import LhotseDataLoader
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from icefall.utils import str2bool
+from icefall.dataset.datamodule import DataModule
+
+
+def get_context_suffix(args, subparser=True):
+    if subparser:
+        if args.giga_context_window is None or args.giga_context_window <= 0.0:
+            ctx_suffix = ""
+        else:
+            ctx_suffix = f"_{args.giga_context_direction}{args.giga_context_window}"
+    else:
+        if args.context_window is None or args.context_window <= 0.0:
+            ctx_suffix = ""
+        else:
+            ctx_suffix = f"_{args.context_direction}{args.context_window}"
+    return ctx_suffix
+
+
+class GigaSpeechAsrDataModule(DataModule):
+    """
+    DataModule for K2 ASR experiments.
+    It assumes there is always one train and one valid dataloader.
+
+    It contains all the common data pipeline modules used in ASR experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - cut concatenation,
+    - augmentation,
+    - on-the-fly feature extraction.
+
+    This class should be derived for specific corpora used in ASR tasks.
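+
+    A minimal usage sketch (it mirrors example_giga_dataloader.py below;
+    the "giga" sub-parser wiring is set up in add_arguments):
+
+        parser = argparse.ArgumentParser()
+        GigaSpeechAsrDataModule.add_arguments(parser)
+        args = parser.parse_args()
+        gigaspeech = GigaSpeechAsrDataModule(args)
+        for batch in gigaspeech.inexhaustible_train_dataloaders():
+            ...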
+ """ + + def __init__(self, args): + self.total_train_cuts = 0 + self.consumed_cuts = 0 + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + subparsers = parser.add_subparsers(help='seperate gigaspeech arguments from librispeech arguments') + parser = subparsers.add_parser(name='giga') + super().add_arguments(parser) + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of PyTorch DataLoaders " + "from Lhotse CutSet's -- they control the effective batch sizes, " + "sampling strategies, applied data augmentations, etc.", + ) + group.add_argument( + "--feature-dir", + dest="giga_feature_dir", + type=Path, + default=Path('exp/giga_data'), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + dest="giga_max_duration", + type=int, + default=500.0, + help="Maximum pooled recordings duration (seconds) in a single batch.", + ) + group.add_argument( + "--bucketing-sampler", + dest="giga_bucketing_sampler", + type=str2bool, + default=False, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + dest="giga_num_buckets", + type=int, + default=30, + help="The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + dest="giga_concatenate_cuts", + type=str2bool, + default=True, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + dest="giga_duration_factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + dest="giga_gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between concatenated cuts. " + "This padding is filled with noise when noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + dest="giga_on_the_fly_feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature extraction. 
" + "Will drop existing precomputed feature manifests if available.", + ) + group.add_argument( + "--shuffle", + dest="giga_shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + dest="giga_return_cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the field: batch['supervisions']['cut']" + " with the cuts that were used to construct it.", + ) + group.add_argument( + "--num-workers", + dest="giga_num_workers", + type=int, + default=4, + help="The number of training dataloader workers that collect the batches.", + ) + group.add_argument( + "--num-workers-inner", + dest="giga_num_workers_inner", + type=int, + default=16, + help="The number of sub-workers (replicated for each of training dataloader" + " workers) that parallelize the I/O to collect each batch.", + ) + + # GigaSpeech specific arguments + group.add_argument( + "--subset", + dest="giga_subset", + type=str, + default="XS", + help="Select the GigaSpeech subset (XS|S|M|L|XL)", + ) + group.add_argument( + "--context-window", + dest="giga_context_window", + type=float, + default=0.0, + help="Training cut duration in seconds. " + "Use 0 to train on supervision segments without acoustic context, with variable cut lengths; " + "number larger than zero will create multi-supervisions cuts with actual acoustic context. ", + ) + group.add_argument( + "--context-direction", + dest="giga_context_direction", + type=str, + default="center", + help="If context-window is 0, does nothing. " + "If it's larger than 0, determines in which direction (relative to the supervision) " + "to seek for extra acoustic context. Available values: (left|right|center|random).", + ) + group.add_argument( + "--use-context-for-test", + dest="giga_use_context_for_text", + type=str2bool, + default=False, + help="Should we read cuts with acoustic context or without it. " + "(note: for now, they may contain duplicated segments)", + ) + group.add_argument( + "--small-dev", + dest="giga_small_dev", + type=str2bool, + default=False, + help="Should we use only 1000 utterances for dev (speeds up training)", + ) + + def validate_args(self): + if self.args.giga_subset in ["L", "XL"]: + assert ( + self.args.giga_shuffle == False + ), "For GigaSpeech L/XL, you must use --shuffle 0 to avoid eagerly reading pyarrow manifests." + assert ( + self.args.giga_bucketing_sampler == False + ), "For GigaSpeech L/XL, you must use --bucketing-sampler 0 to avoid eagerly reading pyarrow manifests." + # compute_and_store_features_batch is efficient for L/XL subsets. + # if not self.args.giga_on_the_fly_feats: + # warnings.warn( + # "For GigaSpeech L/XL, we advise to set --on-the-fly-feats 1," + # " as we do not pre-compute them by default. If you pre-computed them," + # " ignore this warning." + # ) + + def train_dataloaders(self) -> DataLoader: + self.validate_args() + logging.info("About to get train cuts") + cuts_train = self.train_cuts() + self.total_train_cuts = len(cuts_train) + self.consumed_cuts = 0 + + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.giga_feature_dir / "cuts_musan.json.gz") + + logging.info("About to create train dataset") + transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))] + if self.args.giga_concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.giga_duration_factor} and gap {self.args.giga_gap}." 
+
+        if self.args.giga_on_the_fly_feats:
+            # NOTE: the PerturbSpeed transform should be added only if we
+            # remove it from the data prep stage.
+            # # Add on-the-fly speed perturbation; since originally it would
+            # # have increased epoch size by 3, we will apply prob 2/3 and use
+            # # 3x more epochs.
+            # # Speed perturbation probably should come first before
+            # # concatenation, but in principle the transforms order doesn't
+            # # have to be strict (e.g. could be randomized)
+            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(
+                    # To avoid unexpected GPU OOM issue during training,
+                    # I think using the cpu version is safer
+                    # KaldifeatFbank(KaldifeatFbankConfig(device='cuda')),
+                    KaldifeatFbank(KaldifeatFbankConfig()),
+                    num_workers=self.args.giga_num_workers_inner,
+                ),
+                return_cuts=self.args.giga_return_cuts,
+            )
+
+        if self.args.giga_bucketing_sampler:
+            logging.info("Using BucketingSampler.")
+            train_sampler = BucketingSampler(
+                cuts_train,
+                max_duration=self.args.giga_max_duration,
+                shuffle=self.args.giga_shuffle,
+                num_buckets=self.args.giga_num_buckets,
+            )
+        else:
+            logging.info("Using SingleCutSampler.")
+            train_sampler = SingleCutSampler(
+                cuts_train,
+                max_duration=self.args.giga_max_duration,
+                shuffle=self.args.giga_shuffle,
+            )
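+        # Note on dynamic batching: the sampler keeps packing cuts into a
+        # batch until adding another one would exceed max_duration seconds of
+        # audio, so the number of cuts per batch varies from step to step.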
+        logging.info("About to create train dataloader")
+        # train_dl = DataLoader(
+        #     train,
+        #     sampler=train_sampler,
+        #     batch_size=None,
+        #     num_workers=16,
+        #     persistent_workers=True,
+        # )
+        train_dl = LhotseDataLoader(
+            train,
+            sampler=train_sampler,
+            num_workers=self.args.giga_num_workers,
+            prefetch_factor=5,
+        )
+        return train_dl
+
+    def valid_dataloaders(self) -> DataLoader:
+        self.validate_args()
+        logging.info("About to get dev cuts")
+        cuts_valid = self.valid_cuts()
+
+        transforms = []
+        if self.args.giga_concatenate_cuts:
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.giga_duration_factor,
+                    gap=self.args.giga_gap,
+                )
+            ] + transforms
+
+        logging.info("About to create dev dataset")
+        if self.args.giga_on_the_fly_feats:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(
+                    # To avoid unexpected GPU OOM issue during training,
+                    # I think using the cpu version is safer
+                    # KaldifeatFbank(KaldifeatFbankConfig(device='cuda')), num_workers=8
+                    KaldifeatFbank(KaldifeatFbankConfig()),
+                    num_workers=8,
+                ),
+                return_cuts=self.args.giga_return_cuts,
+            )
+        else:
+            validate = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                return_cuts=self.args.giga_return_cuts,
+            )
+        valid_sampler = SingleCutSampler(
+            cuts_valid,
+            max_duration=self.args.giga_max_duration,
+            shuffle=False,
+        )
+        logging.info("About to create dev dataloader")
+        # valid_dl = DataLoader(
+        #     validate,
+        #     sampler=valid_sampler,
+        #     batch_size=None,
+        #     num_workers=8,
+        #     persistent_workers=True,
+        # )
+        valid_dl = LhotseDataLoader(
+            validate,
+            sampler=valid_sampler,
+            num_workers=2,
+        )
+        return valid_dl
+
+    def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
+        self.validate_args()
+        cuts = self.test_cuts()
+        is_list = isinstance(cuts, list)
+        test_loaders = []
+        if not is_list:
+            cuts = [cuts]
+
+        for cuts_test in cuts:
+            logging.debug("About to create test dataset")
+            test = K2SpeechRecognitionDataset(
+                input_strategy=(
+                    # To avoid unexpected GPU OOM issue during training,
+                    # I think using the cpu version is safer
+                    # OnTheFlyFeatures(KaldifeatFbank(KaldifeatFbankConfig(device='cuda')), num_workers=8)
+                    OnTheFlyFeatures(
+                        KaldifeatFbank(KaldifeatFbankConfig()), num_workers=8
+                    )
+                    if self.args.giga_on_the_fly_feats
+                    else PrecomputedFeatures()
+                ),
+                return_cuts=self.args.giga_return_cuts,
+            )
+            sampler = SingleCutSampler(
+                cuts_test, max_duration=self.args.giga_max_duration
+            )
+            logging.debug("About to create test dataloader")
+            # test_dl = DataLoader(test, batch_size=None, sampler=sampler, num_workers=1)
+            test_dl = LhotseDataLoader(test, sampler=sampler, num_workers=2)
+            test_loaders.append(test_dl)
+
+        if is_list:
+            return test_loaders
+        else:
+            return test_loaders[0]
+
+    @lru_cache()
+    def train_cuts(self) -> CutSet:
+        logging.info("About to get train cuts")
+        path = (
+            self.args.giga_feature_dir
+            / f"gigaspeech_cuts_{self.args.giga_subset}{get_context_suffix(self.args)}.jsonl.gz"
+        )
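+        # For illustration: with --subset XL --context-window 2.0
+        # --context-direction center this resolves to
+        # exp/giga_data/gigaspeech_cuts_XL_center2.0.jsonl.gz; with the
+        # default --context-window 0.0 the suffix is empty.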
+        if self.args.giga_subset in ["L", "XL"]:
+            # The "L" and "XL" partitions are large enough that we have to
+            # read their manifests lazily; the "CutSet" holds a file handle
+            # and reads the items sequentially on-the-fly to avoid wasting
+            # memory and time pre-reading everything. Some operations on
+            # "CutSet" won't work, e.g. shuffling (or they would have read
+            # everything into memory in the process). We expect that the
+            # manifests read lazily are pre-shuffled, otherwise you might
+            # experience issues with convergence.
+            cuts_train = CutSet.from_jsonl_lazy(path)
+        else:
+            # For other subsets, just read everything into memory.
+            cuts_train = CutSet.from_file(path)
+        return cuts_train
+
+    @lru_cache()
+    def valid_cuts(self) -> CutSet:
+        if self.args.giga_use_context_for_test:
+            path = (
+                self.args.giga_feature_dir
+                / f"gigaspeech_cuts_DEV{get_context_suffix(self.args)}.jsonl.gz"
+            )
+        else:
+            path = self.args.giga_feature_dir / "gigaspeech_cuts_DEV.jsonl.gz"
+        logging.info(f"About to get valid cuts from {path}")
+        cuts_valid = load_manifest(path)
+        if self.args.giga_small_dev:
+            return cuts_valid.subset(first=1000)
+        else:
+            return cuts_valid
+
+    @lru_cache()
+    def test_cuts(self) -> CutSet:
+        if self.args.giga_use_context_for_test:
+            path = (
+                self.args.giga_feature_dir
+                / f"gigaspeech_cuts_TEST{get_context_suffix(self.args)}.jsonl.gz"
+            )
+        else:
+            path = self.args.giga_feature_dir / "gigaspeech_cuts_TEST.jsonl.gz"
+        logging.info(f"About to get test cuts from {path}")
+        cuts_test = load_manifest(path)
+        return cuts_test
+
+    def inexhaustible_train_dataloaders(self):
+        return self
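+    # __iter__ below makes the datamodule itself an infinite batch iterator:
+    # whenever the current epoch's cuts are exhausted (consumed_cuts catches
+    # up with total_train_cuts), a fresh train dataloader is created, so
+    # callers can simply loop over inexhaustible_train_dataloaders() forever.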
+
+    def __iter__(self):
+        # Workhorse for inexhaustible_train_dataloaders.
+        while True:
+            # self.total_train_cuts == 0 on the first run;
+            # self.consumed_cuts == self.total_train_cuts when the dataloader
+            # needs to be recreated.
+            if (
+                self.total_train_cuts == 0
+                or self.consumed_cuts == self.total_train_cuts
+            ):
+                self.train_dl = self.train_dataloaders()
+                self.consumed_cuts = 0
+
+            for batch in self.train_dl:
+                self.consumed_cuts += len(batch["supervisions"]["text"])
+                yield batch
diff --git a/egs/librispeech/ASR/example_giga_dataloader.py b/egs/librispeech/ASR/example_giga_dataloader.py
new file mode 100644
index 000000000..7f8aa20eb
--- /dev/null
+++ b/egs/librispeech/ASR/example_giga_dataloader.py
@@ -0,0 +1,28 @@
+import argparse
+import json
+from pathlib import Path
+
+from gigaspeech_datamodule import GigaSpeechAsrDataModule
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    group = parser.add_argument_group(title='libri related options')
+    group.add_argument(
+        '--max-duration',
+        type=int,
+        default=500,
+        help="Maximum pooled recordings duration (seconds) in a single batch.",
+    )
+    return parser
+
+
+if __name__ == '__main__':
+    parser = get_parser()
+    GigaSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    gigaspeech = GigaSpeechAsrDataModule(args)
+    train_dl = gigaspeech.inexhaustible_train_dataloaders()
+    for idx, batch in enumerate(train_dl):
+        print(batch["inputs"].shape)
+        print(len(batch["supervisions"]["text"]))
+        print(batch["supervisions"]["text"][0:2])
diff --git a/egs/librispeech/ASR/prepare_gigaspeech.py b/egs/librispeech/ASR/prepare_gigaspeech.py
new file mode 100755
index 000000000..870909787
--- /dev/null
+++ b/egs/librispeech/ASR/prepare_gigaspeech.py
@@ -0,0 +1,313 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2021 Johns Hopkins University (Piotr Żelasko)
+# Apache 2.0
+import argparse
+import os
+import re
+import subprocess
+import sys
+from contextlib import contextmanager
+from pathlib import Path
+from functools import partial
+
+import torch
+
+from gigaspeech_datamodule import get_context_suffix
+from lhotse import (
+    CutSet,
+    KaldifeatFbank,
+    KaldifeatFbankConfig,
+    LilcomHdf5Writer,
+    SupervisionSegment,
+    combine,
+)
+from lhotse.recipes import prepare_gigaspeech, prepare_musan
+from icefall.utils import str2bool
+
+# Torch's multithreaded behavior needs to be disabled, or it wastes a lot of
+# CPU and slows things down. Do this outside of main() in case it needs to
+# take effect even when we are not invoking the main (e.g. when spawning
+# subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+@contextmanager
+def get_executor():
+    # We'll either return a process pool or a distributed worker pool.
+    # Note that this has to be a context manager because we might use multiple
+    # context managers ("with" clauses) inside, and this way everything will
+    # free up the resources at the right time.
+    try:
+        # If this is executed on the CLSP grid, we will try to use the
+        # Grid Engine to distribute the tasks.
+        # Other clusters can also benefit from that, provided a
+        # cluster-specific wrapper (see https://github.com/pzelasko/plz for
+        # reference).
+        #
+        # The following must be installed:
+        # $ pip install dask distributed
+        # $ pip install git+https://github.com/pzelasko/plz
+        name = subprocess.check_output("hostname -f", shell=True, text=True)
+        if name.strip().endswith(".clsp.jhu.edu"):
+            import plz
+            from distributed import Client
+
+            with plz.setup_cluster() as cluster:
+                cluster.scale(80)
+                yield Client(cluster)
+            return
+    except Exception:
+        pass
+    # No need to return anything - compute_and_store_features
+    # will just instantiate the pool itself.
+    yield None
+
+
+def locate_corpus(*corpus_dirs):
+    for d in corpus_dirs:
+        if os.path.exists(d):
+            return d
+    print(
+        "Please create a place on your system to put the downloaded corpus "
+        "data and add it to `corpus_dirs`"
+    )
+    sys.exit(1)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        "--num-jobs",
+        type=int,
+        default=min(15, os.cpu_count()),
+        help="Number of parallel jobs.",
+    )
+    parser.add_argument(
+        "--subset",
+        type=str,
+        default="XS",
+        help="Select the GigaSpeech subset (XS|S|M|L|XL)",
+    )
+    parser.add_argument(
+        "--context-window",
+        type=float,
+        default=0.0,
+        help="Training cut duration in seconds. "
+        "Use 0 to train on supervision segments without acoustic context, "
+        "with variable cut lengths; a number larger than zero will create "
+        "multi-supervision cuts with actual acoustic context.",
+    )
+    parser.add_argument(
+        "--context-direction",
+        type=str,
+        default="center",
+        help="If context-window is 0, does nothing. "
+        "If it's larger than 0, determines in which direction (relative to the supervision) "
+        "to seek for extra acoustic context. Available values: (left|right|center|random).",
+    )
+    parser.add_argument(
+        "--precomputed-features",
+        type=str2bool,
+        default=True,
+        help="Should we pre-compute features and store them on disk or not. "
+        "It is recommended to disable it for the L and XL splits, as the "
+        "pre-computation might currently consume excessive memory and time "
+        "-- use on-the-fly feature extraction in the training script instead.",
+    )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=4,
+        help="Number of workers for compute_and_store_features_batch.",
+    )
+    parser.add_argument(
+        "--batch-duration",
+        type=float,
+        default=600.0,
+        help="The maximum number of audio seconds in a batch "
+        "for compute_and_store_features_batch.",
+    )
+    return parser
+
+
+# Similar text filtering and normalization procedure as in:
+# https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh
+
+
+def normalize_text(
+    utt: str,
+    punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
+    whitespace_pattern=re.compile(r"\s\s+"),
+) -> str:
+    return whitespace_pattern.sub(" ", punct_pattern.sub("", utt))
+
+
+def has_no_oov(
+    sup: SupervisionSegment, oov_pattern=re.compile(r"<(SIL|MUSIC|NOISE|OTHER)>")
+) -> bool:
+    return oov_pattern.search(sup.text) is None
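+
+
+# For illustration (hypothetical strings, not corpus data):
+#   normalize_text("HELLO <COMMA> WORLD")  ->  "HELLO WORLD"
+#   has_no_oov returns False for a supervision whose text contains e.g.
+#   "<MUSIC>" or "<NOISE>", and True otherwise.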
" + "It is recommended to disable it for L and XL splits as the pre-computation " + "might currently consume excessive memory and time -- use on-the-fly feature " + "extraction in the training script instead.", + ) + parser.add_argument( + "--num-workers", + type=int, + default=4, + help="Number of workers for compute_and_store_features_batch.", + ) + parser.add_argument( + "--batch-duration", + type=float, + default=600.0, + help="The maximum number of audio seconds in a batch" + "for compute_and_store_features_batch.", + ) + return parser + + +# Similar text filtering and normalization procedure as in: +# https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh + + +def normalize_text( + utt: str, + punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"), + whitespace_pattern=re.compile(r"\s\s+"), +) -> str: + return whitespace_pattern.sub(" ", punct_pattern.sub("", utt)) + + +def has_no_oov( + sup: SupervisionSegment, oov_pattern=re.compile(r"<(SIL|MUSIC|NOISE|OTHER)>") +) -> bool: + return oov_pattern.search(sup.text) is None + + +def main(): + args = get_parser().parse_args() + dataset_parts = [args.subset, "DEV", "TEST"] + + print("Parts we will prepare: ", dataset_parts) + + corpus_dir = locate_corpus( + Path("/export/corpora5/gigaspeech"), + Path("/exp/pzelasko/gigaspeech"), + Path("/home/storage07/zhangjunbo/data/GigaSpeech") + ) + musan_dir = locate_corpus( + Path("/export/corpora5/JHU/musan"), + Path("/export/common/data/corpora/MUSAN/musan"), + Path("/root/fangjun/data/musan"), + ) + + output_dir = Path("exp/giga_data") + print("GigaSpeech manifest preparation:") + gigaspeech_manifests = prepare_gigaspeech( + corpus_dir=corpus_dir, + dataset_parts=dataset_parts, + output_dir=output_dir, + num_jobs=args.num_jobs, + ) + + print("Musan manifest preparation:") + musan_cuts_path = output_dir / "cuts_musan.json.gz" + musan_manifests = prepare_musan( + corpus_dir=musan_dir, output_dir=output_dir, parts=("music", "speech", "noise") + ) + + ctx_suffix = get_context_suffix(args, subparser=False) + + print("Feature extraction:") + # extractor = Fbank(FbankConfig(num_mel_bins=80)) + extractor = KaldifeatFbank(KaldifeatFbankConfig(device='cuda')) # default config uses 80 mel bins already + with get_executor() as ex: # Initialize the executor only once. + for partition, manifests in gigaspeech_manifests.items(): + raw_cuts_path = output_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz" + cuts_path = ( + output_dir / f"gigaspeech_cuts_{partition}{ctx_suffix}.jsonl.gz" + ) + + if raw_cuts_path.is_file(): + print(f"{partition} already exists - skipping feature extraction.") + else: + # Note this step makes the recipe different than LibriSpeech: + # We must filter out some utterances and remove punctuation to be consistent with Kaldi. + print("Filtering OOV utterances from supervisions") + manifests["supervisions"] = manifests["supervisions"].filter(has_no_oov) + print("Normalizing text in", partition) + for sup in manifests["supervisions"]: + sup.text = normalize_text(sup.text) + + # Create long-recording cut manifests. + print("Processing", partition) + cut_set = CutSet.from_manifests( + recordings=manifests["recordings"], + supervisions=manifests["supervisions"], + ) + + # Run data augmentation that needs to be done in the time domain. 
+                if partition not in ["DEV", "TEST"]:
+                    cut_set = (
+                        cut_set
+                        + cut_set.perturb_speed(0.9)
+                        + cut_set.perturb_speed(1.1)
+                    )
+
+                cut_set.to_file(raw_cuts_path)
+
+            if cuts_path.is_file():
+                print(
+                    f"{partition} already exists - skipping cutting into sub-segments."
+                )
+            else:
+                try:
+                    # If we skipped initializing `cut_set` above (because the
+                    # raw cuts already exist on disk), we'll load it now. This
+                    # helps us avoid re-computing the features for different
+                    # variants of context windows.
+                    cut_set
+                except NameError:
+                    print(f"Reading {partition} raw cuts from disk.")
+                    cut_set = CutSet.from_file(raw_cuts_path)
+                # Note this step makes the recipe different than LibriSpeech:
+                # Since the recordings are long, the initial CutSet has very
+                # long cuts with plenty of supervisions. We cut these into
+                # smaller chunks centered around each supervision, possibly
+                # adding acoustic context.
+                print(f"About to split {partition} raw cuts into smaller chunks.")
+                cut_set = cut_set.trim_to_supervisions(
+                    keep_overlapping=False,
+                    min_duration=None
+                    if args.context_window <= 0.0
+                    else args.context_window,
+                    context_direction=args.context_direction,
+                )
+                if partition in ["L", "XL"]:
+                    # Before storing the manifests, we want to pre-shuffle
+                    # them, as the sampler won't be able to do it later in an
+                    # efficient manner.
+                    cut_set = cut_set.shuffle()
+
+                if args.precomputed_features:
+                    # Extract the features after cutting large recordings into
+                    # smaller cuts.
+                    # Note: we support very efficient "chunked" feature reads
+                    # with the argument `storage_type=ChunkedLilcomHdf5Writer`,
+                    # but we don't support efficient data augmentation and
+                    # feature computation for long recordings yet. Therefore,
+                    # we sacrifice some storage for the ability to precompute
+                    # features on shorter chunks, without memory blow-ups.
+                    # cut_set = cut_set.compute_and_store_features(
+                    #     extractor=extractor,
+                    #     storage_path=f"{output_dir}/feats_gigaspeech_{partition}",
+                    #     # when an executor is specified, make more partitions
+                    #     num_jobs=args.num_jobs if ex is None else 80,
+                    #     executor=ex,
+                    # )
+                    cut_set = cut_set.compute_and_store_features_batch(
+                        extractor=extractor,
+                        storage_path=f"{output_dir}/feats_gigaspeech_{partition}",
+                        batch_duration=args.batch_duration,
+                        num_workers=args.num_workers,
+                        storage_type=partial(LilcomHdf5Writer, tick_power=-3),
+                    )
+
+                cut_set.to_file(cuts_path)
+
+            # Remove cut_set so the next iteration can correctly infer whether
+            # it needs to load the raw cuts from disk or not.
+            del cut_set
+
+        # Now onto Musan.
+        if not musan_cuts_path.is_file():
+            print("Extracting features for Musan")
+            # Create chunks of Musan with duration 5 - 10 seconds.
+            musan_cuts = (
+                CutSet.from_manifests(
+                    recordings=combine(
+                        part["recordings"] for part in musan_manifests.values()
+                    )
+                )
+                .cut_into_windows(10.0)
+                .filter(lambda c: c.duration > 5)
+                .compute_and_store_features_batch(
+                    extractor=extractor,
+                    storage_path=f"{output_dir}/feats_musan",
+                    batch_duration=args.batch_duration,
+                    num_workers=args.num_workers,
+                )
+                # .compute_and_store_features(
+                #     extractor=extractor,
+                #     storage_path=f"{output_dir}/feats_musan",
+                #     num_jobs=args.num_jobs if ex is None else 80,
+                #     executor=ex,
+                #     storage_type=LilcomHdf5Writer,
+                # )
+            )
+            musan_cuts.to_file(musan_cuts_path)
+
+
+if __name__ == "__main__":
+    main()
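+
+# For illustration, typical invocations (the flags are defined in get_parser
+# above; str2bool accepts values like "true"/"false"):
+#   ./prepare_gigaspeech.py --subset XS
+#   ./prepare_gigaspeech.py --subset XL --precomputed-features false
+# The second variant skips feature pre-computation, which pairs with
+# --on-the-fly-feats 1 in the GigaSpeech data module during training.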