mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 17:14:20 +00:00
updated
This commit is contained in:
parent
0350d66f73
commit
ff9d484e71
178
egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py
Normal file
178
egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py
Normal file
@ -0,0 +1,178 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
|
||||
# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
set_audio_duration_mismatch_tolerance,
|
||||
set_caching_enabled,
|
||||
)
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--training-subset",
|
||||
type=str,
|
||||
default="L",
|
||||
help="The training subset for computing fbank feature.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Number of dataloading workers used for reading the audio.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--batch-duration",
|
||||
type=float,
|
||||
default=600.0,
|
||||
help="The maximum number of audio seconds in a batch."
|
||||
"Determines batch size dynamically.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-splits",
|
||||
type=int,
|
||||
required=True,
|
||||
help="The number of splits of the L subset",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--start",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Process pieces starting from this number (inclusive).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--stop",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Stop processing pieces until this number (exclusive).",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def compute_fbank_kespeech_splits(args):
|
||||
subset = args.training_subset
|
||||
subset = str(subset)
|
||||
num_splits = args.num_splits
|
||||
output_dir = f"data/fbank/{subset}_split_{num_splits}"
|
||||
output_dir = Path(output_dir)
|
||||
assert output_dir.exists(), f"{output_dir} does not exist!"
|
||||
|
||||
num_digits = len(str(num_splits))
|
||||
|
||||
start = args.start
|
||||
stop = args.stop
|
||||
if stop < start:
|
||||
stop = num_splits
|
||||
|
||||
stop = min(stop, num_splits)
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance
|
||||
set_caching_enabled(False)
|
||||
for i in range(start, stop):
|
||||
idx = f"{i + 1}".zfill(num_digits)
|
||||
logging.info(f"Processing {idx}/{num_splits}")
|
||||
|
||||
cuts_path = output_dir / f"cuts_{subset}.{idx}.jsonl.gz"
|
||||
if cuts_path.is_file():
|
||||
logging.info(f"{cuts_path} exists - skipping")
|
||||
continue
|
||||
|
||||
raw_cuts_path = output_dir / f"cuts_{subset}_raw.{idx}.jsonl.gz"
|
||||
|
||||
logging.info(f"Loading {raw_cuts_path}")
|
||||
cut_set = CutSet.from_file(raw_cuts_path)
|
||||
|
||||
logging.info("Splitting cuts into smaller chunks.")
|
||||
cut_set = cut_set.trim_to_supervisions(
|
||||
keep_overlapping=False, min_duration=None
|
||||
)
|
||||
|
||||
logging.info("Computing features")
|
||||
cut_set = cut_set.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=f"{output_dir}/feats_{subset}_{idx}",
|
||||
num_workers=args.num_workers,
|
||||
batch_duration=args.batch_duration,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
logging.info(f"Saving to {cuts_path}")
|
||||
cut_set.to_file(cuts_path)
|
||||
|
||||
|
||||
def main():
|
||||
now = datetime.now()
|
||||
date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
|
||||
|
||||
log_filename = "log-compute_fbank_kespeech_splits"
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
log_filename = f"{log_filename}-{date_time}"
|
||||
|
||||
logging.basicConfig(
|
||||
filename=log_filename,
|
||||
format=formatter,
|
||||
level=logging.INFO,
|
||||
filemode="w",
|
||||
)
|
||||
|
||||
console = logging.StreamHandler()
|
||||
console.setLevel(logging.INFO)
|
||||
console.setFormatter(logging.Formatter(formatter))
|
||||
logging.getLogger("").addHandler(console)
|
||||
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
compute_fbank_kespeech_splits(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
134
egs/multi_zh-hans/ASR/local/preprocess_kespeech.py
Normal file
134
egs/multi_zh-hans/ASR/local/preprocess_kespeech.py
Normal file
@ -0,0 +1,134 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
|
||||
# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from lhotse import CutSet, SupervisionSegment
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall import setup_logger
|
||||
|
||||
# Similar text filtering and normalization procedure as in:
|
||||
# https://github.com/SpeechColab/WenetSpeech/blob/main/toolkits/kaldi/wenetspeech_data_prep.sh
|
||||
|
||||
|
||||
def normalize_text(
|
||||
utt: str,
|
||||
# punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
|
||||
punct_pattern=re.compile(r"<(PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
|
||||
whitespace_pattern=re.compile(r"\s\s+"),
|
||||
) -> str:
|
||||
return whitespace_pattern.sub(" ", punct_pattern.sub("", utt))
|
||||
|
||||
|
||||
def has_no_oov(
|
||||
sup: SupervisionSegment,
|
||||
oov_pattern=re.compile(r"<(SIL|MUSIC|NOISE|OTHER)>"),
|
||||
) -> bool:
|
||||
return oov_pattern.search(sup.text) is None
|
||||
|
||||
|
||||
def preprocess_kespeech():
|
||||
src_dir = Path("data/manifests/KeSpeech")
|
||||
output_dir = Path("data/fbank")
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Note: By default, we preprocess all sub-parts.
|
||||
# You can delete those that you don't need.
|
||||
# For instance, if you don't want to use the L subpart, just remove
|
||||
# the line below containing "L"
|
||||
dataset_parts = (
|
||||
"dev_phase1",
|
||||
"dev_phase2",
|
||||
"test",
|
||||
"train_phase1",
|
||||
"train_phase2",
|
||||
)
|
||||
|
||||
logging.info("Loading manifest (may take 10 minutes)")
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
output_dir=src_dir,
|
||||
suffix="jsonl.gz",
|
||||
prefix="kespeech-asr",
|
||||
)
|
||||
assert manifests is not None
|
||||
|
||||
assert len(manifests) == len(dataset_parts), (
|
||||
len(manifests),
|
||||
len(dataset_parts),
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
for partition, m in manifests.items():
|
||||
logging.info(f"Processing {partition}")
|
||||
raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
|
||||
if raw_cuts_path.is_file():
|
||||
logging.info(f"{partition} already exists - skipping")
|
||||
continue
|
||||
|
||||
# Note this step makes the recipe different than LibriSpeech:
|
||||
# We must filter out some utterances and remove punctuation
|
||||
# to be consistent with Kaldi.
|
||||
logging.info("Filtering OOV utterances from supervisions")
|
||||
m["supervisions"] = m["supervisions"].filter(has_no_oov)
|
||||
logging.info(f"Normalizing text in {partition}")
|
||||
for sup in m["supervisions"]:
|
||||
text = str(sup.text)
|
||||
orig_text = text
|
||||
sup.text = normalize_text(sup.text)
|
||||
text = str(sup.text)
|
||||
if len(orig_text) != len(text):
|
||||
logging.info(
|
||||
f"\nOriginal text vs normalized text:\n{orig_text}\n{text}"
|
||||
)
|
||||
|
||||
# Create long-recording cut manifests.
|
||||
logging.info(f"Processing {partition}")
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
# Run data augmentation that needs to be done in the
|
||||
# time domain.
|
||||
if partition not in [
|
||||
"dev_phase1",
|
||||
"dev_phase2",
|
||||
"test",
|
||||
]:
|
||||
logging.info(
|
||||
f"Speed perturb for {partition} with factors 0.9 and 1.1 "
|
||||
"(Perturbing may take 8 minutes and saving may take 20 minutes)"
|
||||
)
|
||||
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
logging.info(f"Saving to {raw_cuts_path}")
|
||||
cut_set.to_file(raw_cuts_path)
|
||||
|
||||
|
||||
def main():
|
||||
setup_logger(log_filename="./log-preprocess-kespeech")
|
||||
|
||||
preprocess_kespeech()
|
||||
logging.info("Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user