diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index 25d6050bb..ce29ddfdd 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -68,6 +68,13 @@ def get_args(): help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", ) + parser.add_argument( + "--num-workers", + type=int, + default=15, + help="Number of worker processes for feature extraction.", + ) + return parser.parse_args() @@ -75,10 +82,11 @@ def compute_fbank_librispeech( bpe_model: Optional[str] = None, dataset: Optional[str] = None, perturb_speed: Optional[bool] = True, + num_workers: int = 15, ): src_dir = Path("data/manifests") output_dir = Path("data/fbank") - num_jobs = min(15, os.cpu_count()) + num_jobs = min(num_workers, os.cpu_count()) num_mel_bins = 80 if bpe_model: @@ -125,6 +133,7 @@ def compute_fbank_librispeech( logging.info(f"{partition} already exists - skipping.") continue logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( recordings=m["recordings"], supervisions=m["supervisions"], @@ -134,20 +143,44 @@ def compute_fbank_librispeech( if bpe_model: cut_set = filter_cuts(cut_set, sp) if perturb_speed: - logging.info(f"Doing speed perturb") + logging.info("Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) + + if ex is None: + # Create a custom process pool context for None (local execution) + import multiprocessing as mp + from concurrent.futures import ProcessPoolExecutor + + # Calculate the number of jobs + actual_jobs = ( + min(num_jobs * 2, 20) if "train" in partition else num_jobs + ) + + # Use the forkserver method + ctx = mp.get_context("forkserver") + with ProcessPoolExecutor( + max_workers=actual_jobs, mp_context=ctx + ) as local_executor: + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + executor=local_executor, + storage_type=LilcomChunkyWriter, + ) + else: + # Distributed environment, use the provided executor + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + num_jobs=min(num_jobs * 2, 20), + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) @@ -161,4 +194,5 @@ if __name__ == "__main__": bpe_model=args.bpe_model, dataset=args.dataset, perturb_speed=args.perturb_speed, + num_workers=args.num_workers, ) diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index cf3dc9adb..b646081ec 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -139,7 +139,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Compute fbank for librispeech" mkdir -p data/fbank if [ ! -e data/fbank/.librispeech.done ]; then - ./local/compute_fbank_librispeech.py + ./local/compute_fbank_librispeech.py --num-workers $nj touch data/fbank/.librispeech.done fi