From 8ca2b2695ef4b449efaf361b9100db90e70bf302 Mon Sep 17 00:00:00 2001 From: yifanyeung Date: Wed, 30 Oct 2024 10:39:05 -0700 Subject: [PATCH] add prepare.sh --- .../TTS/local/extract_speech_tokens.py | 289 ++++++++++++++++++ egs/libriheavy/TTS/local/prepare_manifest.py | 76 +++++ egs/libriheavy/TTS/local/train_bpe_model.py | 1 + egs/libriheavy/TTS/prepare.sh | 131 ++++++++ 4 files changed, 497 insertions(+) create mode 100644 egs/libriheavy/TTS/local/extract_speech_tokens.py create mode 100644 egs/libriheavy/TTS/local/prepare_manifest.py create mode 120000 egs/libriheavy/TTS/local/train_bpe_model.py mode change 100644 => 100755 egs/libriheavy/TTS/prepare.sh diff --git a/egs/libriheavy/TTS/local/extract_speech_tokens.py b/egs/libriheavy/TTS/local/extract_speech_tokens.py new file mode 100644 index 000000000..9b101a482 --- /dev/null +++ b/egs/libriheavy/TTS/local/extract_speech_tokens.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import math +import os +from pathlib import Path +from typing import Optional + +import fairseq +import joblib +import numpy as np +import torch +from lhotse import CutSet, SupervisionSegment +from lhotse.utils import fastcopy +from tqdm import tqdm + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512" + + +class ApplyKmeans(object): + def __init__(self, km_path): + self.km_model = joblib.load(km_path) + self.C_np = self.km_model.cluster_centers_.transpose() + self.Cnorm_np = (self.C_np**2).sum(0, keepdims=True) + + self.C = torch.from_numpy(self.C_np) + self.Cnorm = torch.from_numpy(self.Cnorm_np) + if torch.cuda.is_available(): + self.C = self.C.cuda() + self.Cnorm = self.Cnorm.cuda() + + def __call__(self, x): + if isinstance(x, torch.Tensor): + dist = ( + x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, self.C) + self.Cnorm + ) + return dist.argmin(dim=1).cpu().numpy() + else: + dist = ( + (x**2).sum(1, keepdims=True) + - 2 * np.matmul(x, self.C_np) + + self.Cnorm_np + ) + return np.argmin(dist, axis=1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--subset", + type=str, + default="small", + ) + + parser.add_argument( + "--model-path", + type=str, + default="download/hubert_base_ls960.pt", + ) + + parser.add_argument( + "--kmeans-model-path", + type=str, + default="download/hubert_base_ls960_L9_km500.bin", + ) + + parser.add_argument( + "--start", + type=int, + default=0, + help="Process pieces starting from this number (inclusive).", + ) + + parser.add_argument( + "--stop", + type=int, + default=-1, + help="Stop processing pieces until this number (exclusive).", + ) + + parser.add_argument( + "--window-duration", + type=float, + default=300.0, + ) + + parser.add_argument( + "--shift-duration", + type=float, + default=250.0, + ) + + return parser.parse_args() + + +@torch.no_grad() +def extract_and_save_one_cuts( + raw_cuts_path, + cuts_path, + model, + apply_kmeans, + do_normalize, + window_duration, + shift_duration, +): + logging.info(f"Loading {raw_cuts_path}") + cut_set = CutSet.from_file(raw_cuts_path) + + logging.info("Extracting kmeans") + cuts = [] + + assert window_duration >= shift_duration + window_size = int(window_duration * 16000) + shift_size = int(shift_duration * 16000) + overlap_size = window_size - shift_size + out_overlap_size = get_out_length(overlap_size) + + for cut in tqdm(cut_set): + assert cut.sampling_rate == 16000, f"Sampling rate: {cut.sampling_rate}" + + audio = cut.load_audio() + + T = audio.shape[1] + start = 0 + kmeans = [] + while start < T: + real_window_size = min(window_size, T - start) + audio_window = audio[:, start : start + real_window_size] + + x = ( + torch.from_numpy(audio_window) + .float() + .to(next(model.parameters()).device) + ) + if do_normalize: + x = torch.nn.functional.layer_norm(x, x.shape) + + feature, _ = model.extract_features( + source=x, + padding_mask=None, + mask=False, + output_layer=9, + ) + feature = feature.squeeze(0) + + current_kmeans = apply_kmeans(feature).tolist() + + if start == 0: + kmeans.extend(current_kmeans) + else: + kmeans.extend(current_kmeans[out_overlap_size:]) + + if T - start <= window_size: + break + + start += shift_size + + kmeans = " ".join(map(str, kmeans)) + + cut_with_kmeans = fastcopy( + cut, + custom={"kmeans": kmeans}, + ) + cuts.append(cut_with_kmeans) + + cuts = CutSet(cuts) + + logging.info(f"Saving to {cuts_path}") + cuts.to_file(cuts_path) + + +def extract_kmeans(args): + assert args.subset in ("small", "medium", "large"), f"{args.subset}" + + output_dir = ( + f"data/kmeans/{args.subset}_split" if args.subset != "small" else "data/kmeans" + ) + output_dir = Path(output_dir) + assert output_dir.exists(), f"{output_dir} does not exist!" + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + logging.info(f"device: {device}") + + prefix = "librilight" + + apply_kmeans = ApplyKmeans(args.kmeans_model_path) + model, _, task = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [args.model_path] + ) + model = model[0].eval().to(device) + do_normalize = task.cfg.normalize + + window_duration = args.window_duration + shift_duration = args.shift_duration + + if args.subset == "small": + cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz" + if cuts_path.is_file(): + logging.info(f"{cuts_path} exists - skipping") + return + + raw_cuts_path = output_dir / f"{prefix}_cuts_{args.subset}_raw.jsonl.gz" + if not raw_cuts_path.is_file(): + logging.info(f"{raw_cuts_path} does not exist - skipping it") + return + + extract_and_save_one_cuts( + raw_cuts_path, + cuts_path, + model, + apply_kmeans, + do_normalize, + window_duration, + shift_duration, + ) + else: + num_digits = 8 # num_digits is fixed by lhotse split-lazy + start = args.start + stop = args.stop + assert stop > start, "stop must be larger than start!" + + for i in range(start, stop): + idx = f"{i}".zfill(num_digits) + logging.info(f"Processing {idx}/{stop - 1}") + + cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.{idx}.jsonl.gz" + if cuts_path.is_file(): + logging.info(f"{cuts_path} exists - skipping") + continue + + raw_cuts_path = ( + output_dir / f"{prefix}_cuts_{args.subset}_raw.{idx}.jsonl.gz" + ) + if not raw_cuts_path.is_file(): + logging.info(f"{raw_cuts_path} does not exist - skipping it") + continue + + extract_and_save_one_cuts( + raw_cuts_path, + cuts_path, + model, + apply_kmeans, + do_normalize, + window_duration, + shift_duration, + ) + + +def get_out_length(T): + conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 + for i, (out_channels, kernel_size, stride) in enumerate(conv_layers): + T = math.floor((T - kernel_size) / stride) + 1 + + return max(0, T) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + extract_kmeans(args) diff --git a/egs/libriheavy/TTS/local/prepare_manifest.py b/egs/libriheavy/TTS/local/prepare_manifest.py new file mode 100644 index 000000000..8326e36f5 --- /dev/null +++ b/egs/libriheavy/TTS/local/prepare_manifest.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gzip +import json +import re +import sys +from pathlib import Path + +from tn.english.normalizer import Normalizer as EnNormalizer + +from icefall.utils import str2bool + + +class TextNormlizer: + def __init__(self): + self.en_tn_model = EnNormalizer() + + def __call__(self, text): + # brackets + # Always text inside brackets with numbers in them. Usually corresponds to "(Sam 23:17)" + text = re.sub(r"\([^\)]*\d[^\)]*\)", " ", text) + if remove_brackets: + text = re.sub(r"\([^\)]*\)", " ", text) + + # Apply mappings + table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]") + text = text.translate(table) + + # Remove extra spaces + text = re.sub(r"\s+", " ", text).strip() + normalized_text = re.sub(r"\s+", " ", normalized_text).strip() + + text = self.en_tn_model.normalize(text) + return text.strip() + + +# Assign text of the supervisions and remove unnecessary entries. +def main(): + assert ( + len(sys.argv) == 4 + ), "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR KEEP_CUSTOM_FIELDS" + fname = Path(sys.argv[1]).name + oname = Path(sys.argv[2]) / fname + keep_custom_fields = str2bool(sys.argv[3]) + + tn = TextNormlizer() + + with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout: + for line in fin: + cut = json.loads(line) + cut["supervisions"][0]["text"] = tn( + cut["supervisions"][0]["custom"]["texts"][0] + ) + if not keep_custom_fields: + del cut["supervisions"][0]["custom"] + del cut["custom"] + fout.write((json.dumps(cut) + "\n").encode()) + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/TTS/local/train_bpe_model.py b/egs/libriheavy/TTS/local/train_bpe_model.py new file mode 120000 index 000000000..bbb62cced --- /dev/null +++ b/egs/libriheavy/TTS/local/train_bpe_model.py @@ -0,0 +1 @@ +../../../libriheavy/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/libriheavy/TTS/prepare.sh b/egs/libriheavy/TTS/prepare.sh old mode 100644 new mode 100755 index e69de29bb..00cefbb2d --- a/egs/libriheavy/TTS/prepare.sh +++ b/egs/libriheavy/TTS/prepare.sh @@ -0,0 +1,131 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=15 +stage=-1 +stop_stage=100 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/librilight +# You can find small, medium, large, etc. inside it. +# +# - $dl_dir/libriheavy +# You can find libriheavy_cuts_small.jsonl.gz, libriheavy_cuts_medium.jsonl.gz, etc. inside it. +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/lang_bpe_xxx, +# data/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + 4000 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data +tokens_dir=data/tokens +manifests_dir=data/manifests + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: Download audio data." + # If you have pre-downloaded it to /path/to/librilight, + # you can create a symlink + # + # ln -sfv /path/to/librilight $dl_dir/librilight + # + mkdir -p $dl_dir/librilight + for subset in small medium large; do + log "Downloading ${subset} subset." + if [ ! -d $dl_dir/librilight/${subset} ]; then + wget -P $dl_dir/librilight -c https://dl.fbaipublicfiles.com/librilight/data/${subset}.tar + tar xf $dl_dir/librilight/${subset}.tar -C $dl_dir/librilight + else + log "Skipping download, ${subset} subset exists." + fi + done +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download manifests from huggingface." + + # If you have pre-downloaded it to /path/to/libriheavy, + # you can create a symlink + # + # ln -sfv /path/to/libriheavy $dl_dir/libriheavy + # + mkdir -p $dl_dir/libriheavy + for subset in small medium large dev test_clean test_other; do + if [ ! -e $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz ]; then + log "Downloading ${subset} subset." + wget -P $dl_dir/libriheavy -c https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_${subset}.jsonl.gz + else + log "Skipping download, ${subset} subset exists." + fi + done +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Prepare Libriheavy manifests" + mkdir -p $manifests_dir + for subset in small medium large dev test_clean test_other; do + if [ ! -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then + log "Prepare manifest for subset : ${subset}" + ./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir False + fi + done +fi + +num_per_split=200000 +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Split medium and large subsets." + for subset in medium large; do + log "Spliting subset : $subset" + split_dir=$manifests_dir/libriheavy_${subset}_split + mkdir -p $split_dir + if [ ! -e $split_dir/.split_completed ]; then + lhotse split-lazy $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz $split_dir $num_per_split + touch $split_dir/.split_completed + fi + done +fi + +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Train BPE model for unnormalized text" + if [ ! -f data/punc_texts ]; then + gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \ + | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts + fi + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_punc_bpe_${vocab_size} + mkdir -p $lang_dir + + cp data/punc_texts $lang_dir/text + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --byte-fallback \ + --vocab-size ${vocab_size} \ + --byte-fallback \ + --character-coverage 0.99 \ + --transcript $lang_dir/text + fi + done +fi