add prepare.sh
This commit is contained in:
parent 23137c2987
commit 8ca2b2695e
289 egs/libriheavy/TTS/local/extract_speech_tokens.py Normal file
@@ -0,0 +1,289 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
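#
# Example invocation (illustrative sketch only; the paths are simply the
# defaults of the arguments defined in get_args() below):
#
#   ./local/extract_speech_tokens.py \
#     --subset medium \
#     --start 0 \
#     --stop 10 \
#     --model-path download/hubert_base_ls960.pt \
#     --kmeans-model-path download/hubert_base_ls960_L9_km500.bin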

import argparse
import logging
import math
import os
from pathlib import Path
from typing import Optional

import fairseq
import joblib
import numpy as np
import torch
from lhotse import CutSet, SupervisionSegment
from lhotse.utils import fastcopy
from tqdm import tqdm

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
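# ApplyKmeans below quantizes each HuBERT frame to the id of its nearest
# k-means centroid. It precomputes the centroid matrix C and the squared
# norms ||c||^2 so that, via ||x - c||^2 = ||x||^2 - 2*x.c + ||c||^2,
# the argmin over centroids reduces to one matrix multiply per window.
# The resulting cluster ids are the discrete speech tokens stored in the cuts.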
class ApplyKmeans(object):
    def __init__(self, km_path):
        self.km_model = joblib.load(km_path)
        self.C_np = self.km_model.cluster_centers_.transpose()
        self.Cnorm_np = (self.C_np**2).sum(0, keepdims=True)

        self.C = torch.from_numpy(self.C_np)
        self.Cnorm = torch.from_numpy(self.Cnorm_np)
        if torch.cuda.is_available():
            self.C = self.C.cuda()
            self.Cnorm = self.Cnorm.cuda()

    def __call__(self, x):
        if isinstance(x, torch.Tensor):
            dist = (
                x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, self.C) + self.Cnorm
            )
            return dist.argmin(dim=1).cpu().numpy()
        else:
            dist = (
                (x**2).sum(1, keepdims=True)
                - 2 * np.matmul(x, self.C_np)
                + self.Cnorm_np
            )
            return np.argmin(dist, axis=1)


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--subset",
        type=str,
        default="small",
    )

    parser.add_argument(
        "--model-path",
        type=str,
        default="download/hubert_base_ls960.pt",
    )

    parser.add_argument(
        "--kmeans-model-path",
        type=str,
        default="download/hubert_base_ls960_L9_km500.bin",
    )

    parser.add_argument(
        "--start",
        type=int,
        default=0,
        help="Process pieces starting from this number (inclusive).",
    )

    parser.add_argument(
        "--stop",
        type=int,
        default=-1,
        help="Stop processing pieces until this number (exclusive).",
    )

    parser.add_argument(
        "--window-duration",
        type=float,
        default=300.0,
    )

    parser.add_argument(
        "--shift-duration",
        type=float,
        default=250.0,
    )

    return parser.parse_args()


@torch.no_grad()
def extract_and_save_one_cuts(
    raw_cuts_path,
    cuts_path,
    model,
    apply_kmeans,
    do_normalize,
    window_duration,
    shift_duration,
):
    logging.info(f"Loading {raw_cuts_path}")
    cut_set = CutSet.from_file(raw_cuts_path)

    logging.info("Extracting kmeans")
    cuts = []

    assert window_duration >= shift_duration
    window_size = int(window_duration * 16000)
    shift_size = int(shift_duration * 16000)
    overlap_size = window_size - shift_size
    out_overlap_size = get_out_length(overlap_size)
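    # Long recordings are processed in windows of `window_duration` seconds,
    # advanced by `shift_duration` seconds. Every window after the first
    # re-encodes `overlap_size` samples already covered by the previous
    # window, so its first `out_overlap_size` output frames are dropped
    # below to avoid emitting duplicate tokens.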
    for cut in tqdm(cut_set):
        assert cut.sampling_rate == 16000, f"Sampling rate: {cut.sampling_rate}"

        audio = cut.load_audio()

        T = audio.shape[1]
        start = 0
        kmeans = []
        while start < T:
            real_window_size = min(window_size, T - start)
            audio_window = audio[:, start : start + real_window_size]

            x = (
                torch.from_numpy(audio_window)
                .float()
                .to(next(model.parameters()).device)
            )
            if do_normalize:
                x = torch.nn.functional.layer_norm(x, x.shape)

            feature, _ = model.extract_features(
                source=x,
                padding_mask=None,
                mask=False,
                output_layer=9,
            )
            feature = feature.squeeze(0)

            current_kmeans = apply_kmeans(feature).tolist()

            if start == 0:
                kmeans.extend(current_kmeans)
            else:
                kmeans.extend(current_kmeans[out_overlap_size:])

            if T - start <= window_size:
                break

            start += shift_size

        kmeans = " ".join(map(str, kmeans))

        cut_with_kmeans = fastcopy(
            cut,
            custom={"kmeans": kmeans},
        )
        cuts.append(cut_with_kmeans)

    cuts = CutSet(cuts)

    logging.info(f"Saving to {cuts_path}")
    cuts.to_file(cuts_path)


def extract_kmeans(args):
    assert args.subset in ("small", "medium", "large"), f"{args.subset}"

    output_dir = (
        f"data/kmeans/{args.subset}_split" if args.subset != "small" else "data/kmeans"
    )
    output_dir = Path(output_dir)
    assert output_dir.exists(), f"{output_dir} does not exist!"

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    logging.info(f"device: {device}")

    prefix = "librilight"

    apply_kmeans = ApplyKmeans(args.kmeans_model_path)
    model, _, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [args.model_path]
    )
    model = model[0].eval().to(device)
    do_normalize = task.cfg.normalize

    window_duration = args.window_duration
    shift_duration = args.shift_duration

    if args.subset == "small":
        cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
        if cuts_path.is_file():
            logging.info(f"{cuts_path} exists - skipping")
            return

        raw_cuts_path = output_dir / f"{prefix}_cuts_{args.subset}_raw.jsonl.gz"
        if not raw_cuts_path.is_file():
            logging.info(f"{raw_cuts_path} does not exist - skipping it")
            return

        extract_and_save_one_cuts(
            raw_cuts_path,
            cuts_path,
            model,
            apply_kmeans,
            do_normalize,
            window_duration,
            shift_duration,
        )
    else:
        num_digits = 8  # num_digits is fixed by lhotse split-lazy
        start = args.start
        stop = args.stop
        assert stop > start, "stop must be larger than start!"

        for i in range(start, stop):
            idx = f"{i}".zfill(num_digits)
            logging.info(f"Processing {idx}/{stop - 1}")

            cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.{idx}.jsonl.gz"
            if cuts_path.is_file():
                logging.info(f"{cuts_path} exists - skipping")
                continue

            raw_cuts_path = (
                output_dir / f"{prefix}_cuts_{args.subset}_raw.{idx}.jsonl.gz"
            )
            if not raw_cuts_path.is_file():
                logging.info(f"{raw_cuts_path} does not exist - skipping it")
                continue

            extract_and_save_one_cuts(
                raw_cuts_path,
                cuts_path,
                model,
                apply_kmeans,
                do_normalize,
                window_duration,
                shift_duration,
            )


def get_out_length(T):
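    # (out_channels, kernel_size, stride) of the convolutional waveform
    # frontend used by HuBERT base (total stride 320 samples, i.e. one
    # output frame per 20 ms at 16 kHz); mirroring it lets us map a number
    # of input samples to the number of output frames.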
    conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
    for i, (out_channels, kernel_size, stride) in enumerate(conv_layers):
        T = math.floor((T - kernel_size) / stride) + 1

    return max(0, T)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    args = get_args()
    logging.info(vars(args))
    extract_kmeans(args)
76 egs/libriheavy/TTS/local/prepare_manifest.py Normal file
@@ -0,0 +1,76 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gzip
import json
import re
import sys
from pathlib import Path

from tn.english.normalizer import Normalizer as EnNormalizer
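# `tn` is the WeTextProcessing package; its English Normalizer expands
# numbers, dates, abbreviations, etc. into their written-out (spoken) form.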
from icefall.utils import str2bool


class TextNormalizer:
    def __init__(self):
        self.en_tn_model = EnNormalizer()

    def __call__(self, text, remove_brackets=True):
        # Brackets: always remove text inside brackets that contain numbers,
        # which usually corresponds to references such as "(Sam 23:17)".
        # Other bracketed text is removed only when remove_brackets is True.
        text = re.sub(r"\([^\)]*\d[^\)]*\)", " ", text)
        if remove_brackets:
            text = re.sub(r"\([^\)]*\)", " ", text)

        # Apply mappings
        table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
        text = text.translate(table)

        # Remove extra spaces
        text = re.sub(r"\s+", " ", text).strip()

        text = self.en_tn_model.normalize(text)
        return text.strip()


# Assign text of the supervisions and remove unnecessary entries.
def main():
    assert (
        len(sys.argv) == 4
    ), "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR KEEP_CUSTOM_FIELDS"
    fname = Path(sys.argv[1]).name
    oname = Path(sys.argv[2]) / fname
    keep_custom_fields = str2bool(sys.argv[3])

    tn = TextNormalizer()

    with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout:
        for line in fin:
            cut = json.loads(line)
            cut["supervisions"][0]["text"] = tn(
                cut["supervisions"][0]["custom"]["texts"][0]
            )
            if not keep_custom_fields:
                del cut["supervisions"][0]["custom"]
                del cut["custom"]
            fout.write((json.dumps(cut) + "\n").encode())


if __name__ == "__main__":
    main()
1 egs/libriheavy/TTS/local/train_bpe_model.py Symbolic link
@@ -0,0 +1 @@
../../../libriheavy/ASR/local/train_bpe_model.py
131 egs/libriheavy/TTS/prepare.sh Normal file → Executable file
@@ -0,0 +1,131 @@
#!/usr/bin/env bash

# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

set -eou pipefail

nj=15
stage=-1
stop_stage=100

# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
#  - $dl_dir/librilight
#      You can find small, medium, large, etc. inside it.
#
#  - $dl_dir/libriheavy
#      You can find libriheavy_cuts_small.jsonl.gz, libriheavy_cuts_medium.jsonl.gz, etc. inside it.
dl_dir=$PWD/download

. shared/parse_options.sh || exit 1
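# parse_options.sh lets the variables above be overridden from the command
# line, e.g. (illustrative): ./prepare.sh --stage 0 --stop-stage 3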

# vocab size for sentence piece models.
# It will generate data/lang_bpe_xxx,
# data/lang_bpe_yyy if the array contains xxx, yyy
vocab_sizes=(
  4000
)

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
tokens_dir=data/tokens
manifests_dir=data/manifests

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "dl_dir: $dl_dir"

if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
  log "Stage -1: Download audio data."
  # If you have pre-downloaded it to /path/to/librilight,
  # you can create a symlink
  #
  #   ln -sfv /path/to/librilight $dl_dir/librilight
  #
  mkdir -p $dl_dir/librilight
  for subset in small medium large; do
    log "Downloading ${subset} subset."
    if [ ! -d $dl_dir/librilight/${subset} ]; then
      wget -P $dl_dir/librilight -c https://dl.fbaipublicfiles.com/librilight/data/${subset}.tar
      tar xf $dl_dir/librilight/${subset}.tar -C $dl_dir/librilight
    else
      log "Skipping download, ${subset} subset exists."
    fi
  done
fi

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Download manifests from huggingface."

  # If you have pre-downloaded it to /path/to/libriheavy,
  # you can create a symlink
  #
  #   ln -sfv /path/to/libriheavy $dl_dir/libriheavy
  #
  mkdir -p $dl_dir/libriheavy
  for subset in small medium large dev test_clean test_other; do
    if [ ! -e $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz ]; then
      log "Downloading ${subset} subset."
      wget -P $dl_dir/libriheavy -c https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_${subset}.jsonl.gz
    else
      log "Skipping download, ${subset} subset exists."
    fi
  done
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Prepare Libriheavy manifests"
  mkdir -p $manifests_dir
  for subset in small medium large dev test_clean test_other; do
    if [ ! -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then
      log "Prepare manifest for subset : ${subset}"
      ./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir False
    fi
  done
fi

num_per_split=200000
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  log "Stage 6: Split medium and large subsets."
  for subset in medium large; do
    log "Splitting subset : $subset"
    split_dir=$manifests_dir/libriheavy_${subset}_split
    mkdir -p $split_dir
    if [ ! -e $split_dir/.split_completed ]; then
      lhotse split-lazy $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz $split_dir $num_per_split
      touch $split_dir/.split_completed
    fi
  done
fi

if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
  log "Stage 10: Train BPE model for unnormalized text"
  if [ ! -f data/punc_texts ]; then
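    # jq pulls the raw (punctuated, cased) supervision text out of each cut;
    # sed strips the surrounding JSON quotes and backslash escapes, leaving
    # one utterance per line in data/punc_texts.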
    gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
      | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts
  fi
  for vocab_size in ${vocab_sizes[@]}; do
    lang_dir=data/lang_punc_bpe_${vocab_size}
    mkdir -p $lang_dir

    cp data/punc_texts $lang_dir/text

    if [ ! -f $lang_dir/bpe.model ]; then
      ./local/train_bpe_model.py \
        --lang-dir $lang_dir \
        --byte-fallback \
        --vocab-size ${vocab_size} \
        --character-coverage 0.99 \
        --transcript $lang_dir/text
    fi
  done
fi