on-the-fly feature extraction by default

wgb14 2021-11-13 17:45:35 -05:00
parent 75860159a2
commit 1d58765bd5
3 changed files with 183 additions and 31 deletions

.gitignore

@@ -6,3 +6,5 @@ exp
 exp*/
 *.pt
 download
+dask-worker-space
+log

local/compute_fbank_gigaspeech.py

@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -23,15 +23,23 @@ It looks for manifests in the directory data/manifests.
 The generated fbank features are saved in data/fbank.
 """
+import argparse
 import logging
 import os
+import re
 from pathlib import Path
 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomHdf5Writer,
+    SupervisionSegment,
+)
 from lhotse.recipes.utils import read_manifests_if_cached
-from icefall.utils import get_executor
+from icefall.utils import get_executor, str2bool
 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
@@ -41,10 +49,76 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
 
-def compute_fbank_gigaspeech():
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        "--num-jobs",
+        type=int,
+        default=min(15, os.cpu_count()),
+        help="Number of parallel jobs.",
+    )
+    parser.add_argument(
+        "--context-window",
+        type=float,
+        default=0.0,
+        help="Training cut duration in seconds. "
+        "Use 0 to train on supervision segments without acoustic context, "
+        "with variable cut lengths; number larger than zero will create "
+        "multi-supervisions cuts with actual acoustic context. ",
+    )
+    parser.add_argument(
+        "--context-direction",
+        type=str,
+        default="center",
+        help="If context-window is 0, does nothing. "
+        "If it's larger than 0, determines in which direction "
+        "(relative to the supervision) to seek for extra acoustic context. "
+        "Available values: (left|right|center|random).",
+    )
+    parser.add_argument(
+        "--precomputed-features",
+        type=str2bool,
+        default=False,
+        help="Should we pre-compute features and store them on disk or not. "
+        "It is recommended to disable it for L and XL splits as the "
+        "pre-computation might currently consume excessive memory and time "
+        "-- use on-the-fly feature extraction in the training script instead.",
+    )
+    return parser
+
+
+# Similar text filtering and normalization procedure as in:
+# https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh
+def normalize_text(
+    utt: str,
+    punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
+    whitespace_pattern=re.compile(r"\s\s+"),
+) -> str:
+    return whitespace_pattern.sub(" ", punct_pattern.sub("", utt))
+
+
+def has_no_oov(
+    sup: SupervisionSegment,
+    oov_pattern=re.compile(r"<(SIL|MUSIC|NOISE|OTHER)>"),
+) -> bool:
+    return oov_pattern.search(sup.text) is None
+
+
+def get_context_suffix(args):
+    if args.context_window is None or args.context_window <= 0.0:
+        ctx_suffix = ""
+    else:
+        ctx_suffix = f"_{args.context_direction}{args.context_window}"
+    return ctx_suffix
+
+
+def compute_fbank_gigaspeech(args):
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
-    num_jobs = min(10, os.cpu_count())
     num_mel_bins = 80
 
     dataset_parts = (
@@ -61,39 +135,114 @@ def compute_fbank_gigaspeech():
     assert manifests is not None
 
     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+    ctx_suffix = get_context_suffix(args)
 
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
-            if (output_dir / f"cuts_{partition}.jsonl.gz").is_file():
-                logging.info(f"{partition} already exists - skipping.")
-                continue
+            raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
+            if raw_cuts_path.is_file():
+                logging.info(
+                    f"{partition} already exists - skipping feature extraction."
+                )
+            else:
+                # Note this step makes the recipe different than LibriSpeech:
+                # We must filter out some utterances and remove punctuation
+                # to be consistent with Kaldi.
+                logging.info("Filtering OOV utterances from supervisions")
+                m["supervisions"] = m["supervisions"].filter(has_no_oov)
+                logging.info(f"Normalizing text in {partition}")
+                for sup in m["supervisions"]:
+                    sup.text = normalize_text(sup.text)
+
+                # Create long-recording cut manifests.
                 logging.info(f"Processing {partition}")
                 cut_set = CutSet.from_manifests(
                     recordings=m["recordings"],
                     supervisions=m["supervisions"],
                 )
-            if "train" in partition:
+
+                # Run data augmentation that needs to be done in the
+                # time domain.
+                if partition not in ["DEV", "TEST"]:
                     cut_set = (
                         cut_set
                         + cut_set.perturb_speed(0.9)
                         + cut_set.perturb_speed(1.1)
                     )
+                cut_set.to_file(raw_cuts_path)
+
+            cuts_path = output_dir / f"cuts_{partition}{ctx_suffix}.jsonl.gz"
+            if cuts_path.is_file():
+                logging.info(
+                    f"{partition} already exists - skipping cutting into "
+                    f"sub-segments."
+                )
+            else:
+                try:
+                    # If we skipped initializing `cut_set` because it exists
+                    # on disk, we'll load it. This helps us avoid re-computing
+                    # the features for different variants of context windows.
+                    cut_set
+                except NameError:
+                    logging.info(f"Reading {partition} raw cuts from disk.")
+                    cut_set = CutSet.from_file(raw_cuts_path)
+
+                # Note this step makes the recipe different than LibriSpeech:
+                # Since recordings are long, the initial CutSet has very long
+                # cuts with a plenty of supervisions. We cut these into smaller
+                # chunks centered around each supervision, possibly adding
+                # acoustic context.
+                logging.info(
+                    f"About to split {partition} raw cuts into smaller chunks."
+                )
+                cut_set = cut_set.trim_to_supervisions(
+                    keep_overlapping=False,
+                    min_duration=None
+                    if args.context_window <= 0.0
+                    else args.context_window,
+                    context_direction=args.context_direction,
+                )
+
+                if partition in ["L", "XL"]:
+                    # Before storing manifests in, we want to pre-shuffle them,
+                    # as the sampler won't be able to do it later in an
+                    # efficient manner.
+                    cut_set = cut_set.shuffle()
+
+                if args.precomputed_features:
+                    # Extract the features after cutting large recordings into
+                    # smaller cuts.
+                    # Note:
+                    # we support very efficient "chunked" feature reads with
+                    # the argument `storage_type=ChunkedLilcomHdf5Writer`,
+                    # but we don't support efficient data augmentation and
+                    # feature computation for long recordings yet.
+                    # Therefore, we sacrifice some storage for the ability to
+                    # precompute features on shorter chunks,
+                    # without memory blow-ups.
                     cut_set = cut_set.compute_and_store_features(
                         extractor=extractor,
                         storage_path=f"{output_dir}/feats_{partition}",
                         # when an executor is specified, make more partitions
-                        num_jobs=num_jobs if ex is None else 80,
+                        num_jobs=args.num_jobs if ex is None else 80,
                         executor=ex,
                         storage_type=LilcomHdf5Writer,
                     )
-            cut_set.to_json(output_dir / f"cuts_{partition}.jsonl.gz")
+                cut_set.to_file(cuts_path)
+
+            # Remove cut_set so the next iteration can correctly infer
+            # whether it needs to load the raw cuts from disk or not.
+            del cut_set
 
 
-if __name__ == "__main__":
+def main():
     formatter = (
         "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     )
     logging.basicConfig(format=formatter, level=logging.INFO)
-    compute_fbank_gigaspeech()
+    parser = get_parser()
+    args = parser.parse_args()
+    compute_fbank_gigaspeech(args)
+
+
+if __name__ == "__main__":
+    main()
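
A quick way to see what the new filtering and normalization helpers do is the stripped-down sketch below. The sample transcripts are invented, and in the script `has_no_oov` receives a `SupervisionSegment` (the check runs on `sup.text`), not a plain string; the regexes are copied verbatim from the commit.

# Illustrative only: sample transcripts are made up, but the regex behavior
# matches the helpers added in this commit.
import re

punct_pattern = re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>")
whitespace_pattern = re.compile(r"\s\s+")
oov_pattern = re.compile(r"<(SIL|MUSIC|NOISE|OTHER)>")


def normalize_text(utt: str) -> str:
    # Drop GigaSpeech punctuation tags, then squeeze repeated whitespace.
    return whitespace_pattern.sub(" ", punct_pattern.sub("", utt))


def has_no_oov(text: str) -> bool:
    # In the script this is applied to SupervisionSegment.text via .filter().
    return oov_pattern.search(text) is None


print(normalize_text("AS THE RAIN STOPPED <COMMA> WE WENT OUTSIDE"))
# -> 'AS THE RAIN STOPPED WE WENT OUTSIDE'
print(has_no_oov("SOME SPEECH <MUSIC> MORE SPEECH"))  # -> False (dropped)
print(has_no_oov("ONLY REGULAR WORDS HERE"))          # -> True (kept)

With the default `--context-window 0.0` the context suffix is empty, so the final manifests keep the `cuts_{partition}.jsonl.gz` name; for example, `--context-window 2.0 --context-direction center` would instead write `cuts_{partition}_center2.0.jsonl.gz`, with each cut padded to at least 2 seconds of audio around its supervision.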

prepare.sh

@@ -110,7 +110,8 @@ fi
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "Stage 3: Compute fbank for GigaSpeech"
   mkdir -p data/fbank
-  ./local/compute_fbank_gigaspeech.py
+  ./local/compute_fbank_gigaspeech.py --num-jobs $nj --context-window 0.0 \
+    --context-direction center --precomputed-features False
 fi
 
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
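
With `--precomputed-features False` (now the default both in the script and in the stage 3 invocation above), no fbank archives are written; the training script is expected to compute features on the fly from the cut manifests. A minimal sketch of how that is typically wired with lhotse is shown below. The class names come from lhotse's public API, but the actual dataloader setup lives in the training code and is not part of this commit; the manifest path, `max_duration`, and `num_workers` values are illustrative.

# Sketch only: mirrors the usual lhotse on-the-fly feature setup; the exact
# wiring in the training script may differ from this example.
import torch
from lhotse import CutSet, Fbank, FbankConfig
from lhotse.dataset import (
    K2SpeechRecognitionDataset,
    OnTheFlyFeatures,
    SingleCutSampler,
)

# Cuts produced by compute_fbank_gigaspeech.py without precomputed features.
cuts = CutSet.from_file("data/fbank/cuts_XL.jsonl.gz")

dataset = K2SpeechRecognitionDataset(
    # Features are computed per batch in the dataloader workers instead of
    # being read from precomputed HDF5 archives on disk.
    input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
)
sampler = SingleCutSampler(cuts, max_duration=200.0)  # seconds of audio per batch
loader = torch.utils.data.DataLoader(
    dataset, sampler=sampler, batch_size=None, num_workers=4
)

for batch in loader:
    feats = batch["inputs"]  # (N, T, 80) fbank features computed on the fly
    break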