From fdc04708603d0eb53f2de659f8a2f08b4cb7a28c Mon Sep 17 00:00:00 2001
From: yifanyeung
Date: Sat, 2 Nov 2024 22:44:43 -0700
Subject: [PATCH] Update libriheavy TTS prepare.sh and speech token extraction

Renumber the data-preparation stages in prepare.sh, train the BPE model
on text normalized with local/norm_text.py, and strip the HuBERT
k-means specific code from local/extract_speech_tokens.py so that it
attaches generic speech tokens to the Libriheavy cuts.
---
 .../TTS/local/extract_speech_tokens.py | 180 +++----------------
 egs/libriheavy/TTS/local/norm_text.py  |   1 +
 egs/libriheavy/TTS/prepare.sh          |  30 +--
 3 files changed, 40 insertions(+), 171 deletions(-)
 create mode 120000 egs/libriheavy/TTS/local/norm_text.py

diff --git a/egs/libriheavy/TTS/local/extract_speech_tokens.py b/egs/libriheavy/TTS/local/extract_speech_tokens.py
index 9b101a482..1ae80f953 100644
--- a/egs/libriheavy/TTS/local/extract_speech_tokens.py
+++ b/egs/libriheavy/TTS/local/extract_speech_tokens.py
@@ -17,55 +17,15 @@
 import argparse
 import logging
-import math
 import os
 from pathlib import Path
 from typing import Optional
 
-import fairseq
-import joblib
-import numpy as np
 import torch
 from lhotse import CutSet, SupervisionSegment
 from lhotse.utils import fastcopy
 from tqdm import tqdm
 
-# Torch's multithreaded behavior needs to be disabled or
-# it wastes a lot of CPU and slow things down.
-# Do this outside of main() in case it needs to take effect
-# even when we are not invoking the main (e.g. when spawning subprocesses).
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
-
-
-class ApplyKmeans(object):
-    def __init__(self, km_path):
-        self.km_model = joblib.load(km_path)
-        self.C_np = self.km_model.cluster_centers_.transpose()
-        self.Cnorm_np = (self.C_np**2).sum(0, keepdims=True)
-
-        self.C = torch.from_numpy(self.C_np)
-        self.Cnorm = torch.from_numpy(self.Cnorm_np)
-        if torch.cuda.is_available():
-            self.C = self.C.cuda()
-            self.Cnorm = self.Cnorm.cuda()
-
-    def __call__(self, x):
-        if isinstance(x, torch.Tensor):
-            dist = (
-                x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, self.C) + self.Cnorm
-            )
-            return dist.argmin(dim=1).cpu().numpy()
-        else:
-            dist = (
-                (x**2).sum(1, keepdims=True)
-                - 2 * np.matmul(x, self.C_np)
-                + self.Cnorm_np
-            )
-            return np.argmin(dist, axis=1)
-
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -82,12 +42,6 @@ def get_args():
         default="download/hubert_base_ls960.pt",
     )
 
-    parser.add_argument(
-        "--kmeans-model-path",
-        type=str,
-        default="download/hubert_base_ls960_L9_km500.bin",
-    )
-
     parser.add_argument(
         "--start",
         type=int,
         default=0,
@@ -102,90 +56,27 @@ def get_args():
         help="Stop processing pieces until this number (exclusive).",
     )
 
-    parser.add_argument(
-        "--window-duration",
-        type=float,
-        default=300.0,
-    )
-
-    parser.add_argument(
-        "--shift-duration",
-        type=float,
-        default=250.0,
-    )
-
     return parser.parse_args()
 
 
 @torch.no_grad()
 def extract_and_save_one_cuts(
-    raw_cuts_path,
+    manifests_path,
     cuts_path,
-    model,
-    apply_kmeans,
-    do_normalize,
-    window_duration,
-    shift_duration,
 ):
-    logging.info(f"Loading {raw_cuts_path}")
-    cut_set = CutSet.from_file(raw_cuts_path)
+    logging.info(f"Loading {manifests_path}")
+    cut_set = CutSet.from_file(manifests_path)
 
-    logging.info("Extracting kmeans")
+    logging.info("Extracting tokens")
     cuts = []
 
-    assert window_duration >= shift_duration
-    window_size = int(window_duration * 16000)
-    shift_size = int(shift_duration * 16000)
-    overlap_size = window_size - shift_size
-    out_overlap_size = get_out_length(overlap_size)
-
-    for cut in tqdm(cut_set):
-        assert cut.sampling_rate == 16000, f"Sampling rate: {cut.sampling_rate}"
-
-        audio = cut.load_audio()
-
-        T = audio.shape[1]
-        start = 0
-        kmeans = []
-        while start < T:
-            real_window_size = min(window_size, T - start)
-            audio_window = audio[:, start : start + real_window_size]
-
-            x = (
-                torch.from_numpy(audio_window)
-                .float()
-                .to(next(model.parameters()).device)
-            )
-            if do_normalize:
-                x = torch.nn.functional.layer_norm(x, x.shape)
-
-            feature, _ = model.extract_features(
-                source=x,
-                padding_mask=None,
-                mask=False,
-                output_layer=9,
-            )
-            feature = feature.squeeze(0)
-
-            current_kmeans = apply_kmeans(feature).tolist()
-
-            if start == 0:
-                kmeans.extend(current_kmeans)
-            else:
-                kmeans.extend(current_kmeans[out_overlap_size:])
-
-            if T - start <= window_size:
-                break
-
-            start += shift_size
-
-        kmeans = " ".join(map(str, kmeans))
-
-        cut_with_kmeans = fastcopy(
-            cut,
-            custom={"kmeans": kmeans},
-        )
-        cuts.append(cut_with_kmeans)
+    for cut in tqdm(cut_set):
+        # Placeholder: `tokens` is expected to hold the list of integer
+        # speech-token ids produced for this cut by the chosen speech
+        # tokenizer; the tokenizer itself is not part of this script yet.
+        tokens = " ".join(map(str, tokens))
+
+        cut_with_tokens = fastcopy(
+            cut,
+            custom={"tokens": tokens},
+        )
+        cuts.append(cut_with_tokens)
 
     cuts = CutSet(cuts)
 
@@ -193,11 +84,11 @@ def extract_and_save_one_cuts(
     cuts.to_file(cuts_path)
 
 
-def extract_kmeans(args):
+def extract_speech_tokens(args):
     assert args.subset in ("small", "medium", "large"), f"{args.subset}"
 
     output_dir = (
-        f"data/kmeans/{args.subset}_split" if args.subset != "small" else "data/kmeans"
+        f"data/tokens/{args.subset}_split" if args.subset != "small" else "data/tokens"
     )
     output_dir = Path(output_dir)
     assert output_dir.exists(), f"{output_dir} does not exist!"
@@ -207,17 +98,7 @@ def extract_kmeans(args):
     device = torch.device("cuda", 0)
     logging.info(f"device: {device}")
 
-    prefix = "librilight"
-
-    apply_kmeans = ApplyKmeans(args.kmeans_model_path)
-    model, _, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
-        [args.model_path]
-    )
-    model = model[0].eval().to(device)
-    do_normalize = task.cfg.normalize
-
-    window_duration = args.window_duration
-    shift_duration = args.shift_duration
+    prefix = "libriheavy"
 
     if args.subset == "small":
         cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
@@ -225,16 +106,16 @@ def extract_kmeans(args):
         if cuts_path.is_file():
             logging.info(f"{cuts_path} exists - skipping")
             return
 
-        raw_cuts_path = output_dir / f"{prefix}_cuts_{args.subset}_raw.jsonl.gz"
-        if not raw_cuts_path.is_file():
-            logging.info(f"{raw_cuts_path} does not exist - skipping it")
+        manifests_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
+        if not manifests_path.is_file():
+            logging.info(f"{manifests_path} does not exist - skipping it")
             return
 
         extract_and_save_one_cuts(
-            raw_cuts_path,
+            manifests_path,
             cuts_path,
-            model,
-            apply_kmeans,
-            do_normalize,
-            window_duration,
-            shift_duration,
         )
@@ -254,36 +135,23 @@ def extract_kmeans(args):
                 logging.info(f"{cuts_path} exists - skipping")
                 continue
 
-            raw_cuts_path = (
-                output_dir / f"{prefix}_cuts_{args.subset}_raw.{idx}.jsonl.gz"
+            manifests_path = (
+                output_dir / f"{prefix}_cuts_{args.subset}.{idx}.jsonl.gz"
             )
-            if not raw_cuts_path.is_file():
-                logging.info(f"{raw_cuts_path} does not exist - skipping it")
+            if not manifests_path.is_file():
+                logging.info(f"{manifests_path} does not exist - skipping it")
                 continue
 
             extract_and_save_one_cuts(
-                raw_cuts_path,
+                manifests_path,
                 cuts_path,
-                model,
-                apply_kmeans,
-                do_normalize,
-                window_duration,
-                shift_duration,
             )
 
 
-def get_out_length(T):
-    conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
-    for i, (out_channels, kernel_size, stride) in enumerate(conv_layers):
-        T = math.floor((T - kernel_size) / stride) + 1
-
-    return max(0, T)
-
-
 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
     args = get_args()
     logging.info(vars(args))
 
-    extract_kmeans(args)
+    extract_speech_tokens(args)
diff --git a/egs/libriheavy/TTS/local/norm_text.py b/egs/libriheavy/TTS/local/norm_text.py
new file mode 120000
index 000000000..41dce0944
--- /dev/null
+++ b/egs/libriheavy/TTS/local/norm_text.py
@@ -0,0 +1 @@
+../../ASR/local/norm_text.py
\ No newline at end of file
diff --git a/egs/libriheavy/TTS/prepare.sh b/egs/libriheavy/TTS/prepare.sh
index 208c01024..d2fe27928 100755
--- a/egs/libriheavy/TTS/prepare.sh
+++ b/egs/libriheavy/TTS/prepare.sh
@@ -81,8 +81,8 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   done
 fi
 
-if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 3: Prepare Libriheavy manifests"
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare Libriheavy manifests"
   mkdir -p $manifests_dir
   for subset in small medium large dev test_clean test_other; do
     if [ ! -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then
@@ -93,8 +93,8 @@
 fi
 
 num_per_split=200000
-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "Stage 6: Split medium and large subsets."
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Split medium and large subsets."
   for subset in medium large; do
     log "Spliting subset : $subset"
     split_dir=$manifests_dir/libriheavy_${subset}_split
@@ -106,26 +106,26 @@
   done
 fi
 
-if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
-  log "Stage 10: Train BPE model for unnormalized text"
-  if [ ! -f data/punc_texts ]; then
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Train BPE model for normalized text"
+
+  if [ ! -f data/texts ]; then
     gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
-      | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts
+      | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
+      | ./local/norm_text.py > data/texts
   fi
+
   for vocab_size in ${vocab_sizes[@]}; do
-    lang_dir=data/lang_punc_bpe_${vocab_size}
+    lang_dir=data/lang_bpe_${vocab_size}
     mkdir -p $lang_dir
-    cp data/punc_texts $lang_dir/text
+    cp data/texts $lang_dir/text
     if [ ! -f $lang_dir/bpe.model ]; then
       ./local/train_bpe_model.py \
         --lang-dir $lang_dir \
-        --byte-fallback \
-        --vocab-size ${vocab_size} \
-        --byte-fallback \
-        --character-coverage 0.99 \
+        --vocab-size $vocab_size \
         --transcript $lang_dir/text
     fi
   done
 fi
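
For reference, a minimal sketch of how the pipeline might be driven after this patch. The stage numbers and the --subset/--start/--stop options come from the diffs above; everything else (the --stage/--stop-stage flags, the split range, and the jq path used to peek at the result) is an assumption, not part of the patch:

  # Run the renumbered preparation stages (1: manifests, 2: split medium/large,
  # 3: BPE on normalized text); assumes prepare.sh parses --stage/--stop-stage
  # via shared/parse_options.sh as in other icefall recipes.
  ./prepare.sh --stage 1 --stop-stage 3

  # Extract speech tokens for the small subset (reads and writes under
  # data/tokens/ as in extract_speech_tokens.py above).
  python3 ./local/extract_speech_tokens.py --subset small

  # For the split "medium" subset, pick a range of pieces per job; the
  # 0..100 range here is only an example.
  python3 ./local/extract_speech_tokens.py --subset medium --start 0 --stop 100

  # Peek at the tokens attached to the first cut (the .custom.tokens path is
  # an assumption about how lhotse serializes the custom field).
  gunzip -c data/tokens/libriheavy_cuts_small.jsonl.gz | head -n 1 | jq '.custom.tokens'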