From 1d58765bd52b14401a27b1b91847b84548aadc3a Mon Sep 17 00:00:00 2001 From: wgb14 Date: Sat, 13 Nov 2021 17:45:35 -0500 Subject: [PATCH] on-the-fly feature extraction by default --- .gitignore | 2 + .../ASR/local/compute_fbank_gigaspeech.py | 209 +++++++++++++++--- egs/gigaspeech/ASR/prepare.sh | 3 +- 3 files changed, 183 insertions(+), 31 deletions(-) diff --git a/.gitignore b/.gitignore index f4f703243..d84ea96b4 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,5 @@ exp exp*/ *.pt download +dask-worker-space +log diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py index ea5d3dc6e..9a9088ebc 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Johns Hopkins University (Piotr Żelasko) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -23,15 +23,23 @@ It looks for manifests in the directory data/manifests. The generated fbank features are saved in data/fbank. """ +import argparse import logging import os +import re from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomHdf5Writer, + SupervisionSegment, +) from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -41,10 +49,76 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_gigaspeech(): +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--num-jobs", + type=int, + default=min(15, os.cpu_count()), + help="Number of parallel jobs.", + ) + parser.add_argument( + "--context-window", + type=float, + default=0.0, + help="Training cut duration in seconds. " + "Use 0 to train on supervision segments without acoustic context, " + "with variable cut lengths; number larger than zero will create " + "multi-supervisions cuts with actual acoustic context. ", + ) + parser.add_argument( + "--context-direction", + type=str, + default="center", + help="If context-window is 0, does nothing. " + "If it's larger than 0, determines in which direction " + "(relative to the supervision) to seek for extra acoustic context. " + "Available values: (left|right|center|random).", + ) + parser.add_argument( + "--precomputed-features", + type=str2bool, + default=False, + help="Should we pre-compute features and store them on disk or not. " + "It is recommended to disable it for L and XL splits as the " + "pre-computation might currently consume excessive memory and time " + "-- use on-the-fly feature extraction in the training script instead.", + ) + return parser + + +# Similar text filtering and normalization procedure as in: +# https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh + + +def normalize_text( + utt: str, + punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"), + whitespace_pattern=re.compile(r"\s\s+"), +) -> str: + return whitespace_pattern.sub(" ", punct_pattern.sub("", utt)) + + +def has_no_oov( + sup: SupervisionSegment, + oov_pattern=re.compile(r"<(SIL|MUSIC|NOISE|OTHER)>"), +) -> bool: + return oov_pattern.search(sup.text) is None + + +def get_context_suffix(args): + if args.context_window is None or args.context_window <= 0.0: + ctx_suffix = "" + else: + ctx_suffix = f"_{args.context_direction}{args.context_window}" + return ctx_suffix + + +def compute_fbank_gigaspeech(args): src_dir = Path("data/manifests") output_dir = Path("data/fbank") - num_jobs = min(10, os.cpu_count()) num_mel_bins = 80 dataset_parts = ( @@ -61,39 +135,114 @@ def compute_fbank_gigaspeech(): assert manifests is not None extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + ctx_suffix = get_context_suffix(args) with get_executor() as ex: # Initialize the executor only once. for partition, m in manifests.items(): - if (output_dir / f"cuts_{partition}.jsonl.gz").is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if "train" in partition: - cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) + raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz" + if raw_cuts_path.is_file(): + logging.info( + f"{partition} already exists - skipping feature extraction." ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomHdf5Writer, - ) - cut_set.to_json(output_dir / f"cuts_{partition}.jsonl.gz") + else: + # Note this step makes the recipe different than LibriSpeech: + # We must filter out some utterances and remove punctuation + # to be consistent with Kaldi. + logging.info("Filtering OOV utterances from supervisions") + m["supervisions"] = m["supervisions"].filter(has_no_oov) + logging.info(f"Normalizing text in {partition}") + for sup in m["supervisions"]: + sup.text = normalize_text(sup.text) + + # Create long-recording cut manifests. + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + # Run data augmentation that needs to be done in the + # time domain. + if partition not in ["DEV", "TEST"]: + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) + cut_set.to_file(raw_cuts_path) + + cuts_path = output_dir / f"cuts_{partition}{ctx_suffix}.jsonl.gz" + if cuts_path.is_file(): + logging.info( + f"{partition} already exists - skipping cutting into " + f"sub-segments." + ) + else: + try: + # If we skipped initializing `cut_set` because it exists + # on disk, we'll load it. This helps us avoid re-computing + # the features for different variants of context windows. + cut_set + except NameError: + logging.info(f"Reading {partition} raw cuts from disk.") + cut_set = CutSet.from_file(raw_cuts_path) + # Note this step makes the recipe different than LibriSpeech: + # Since recordings are long, the initial CutSet has very long + # cuts with a plenty of supervisions. We cut these into smaller + # chunks centered around each supervision, possibly adding + # acoustic context. + logging.info( + f"About to split {partition} raw cuts into smaller chunks." + ) + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, + min_duration=None + if args.context_window <= 0.0 + else args.context_window, + context_direction=args.context_direction, + ) + if partition in ["L", "XL"]: + # Before storing manifests in, we want to pre-shuffle them, + # as the sampler won't be able to do it later in an + # efficient manner. + cut_set = cut_set.shuffle() + + if args.precomputed_features: + # Extract the features after cutting large recordings into + # smaller cuts. + # Note: + # we support very efficient "chunked" feature reads with + # the argument `storage_type=ChunkedLilcomHdf5Writer`, + # but we don't support efficient data augmentation and + # feature computation for long recordings yet. + # Therefore, we sacrifice some storage for the ability to + # precompute features on shorter chunks, + # without memory blow-ups. + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=args.num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomHdf5Writer, + ) + cut_set.to_file(cuts_path) + + # Remove cut_set so the next iteration can correctly infer + # whether it needs to load the raw cuts from disk or not. + del cut_set -if __name__ == "__main__": +def main(): formatter = ( "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" ) - logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_gigaspeech() + parser = get_parser() + args = parser.parse_args() + + compute_fbank_gigaspeech(args) + + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/ASR/prepare.sh b/egs/gigaspeech/ASR/prepare.sh index fd6fb217f..46f99b6b2 100755 --- a/egs/gigaspeech/ASR/prepare.sh +++ b/egs/gigaspeech/ASR/prepare.sh @@ -110,7 +110,8 @@ fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Compute fbank for GigaSpeech" mkdir -p data/fbank - ./local/compute_fbank_gigaspeech.py + ./local/compute_fbank_gigaspeech.py --num-jobs $nj --context-window 0.0 \ + --context-direction center --precomputed-features False fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then