add prepare.sh

yifanyeung 2024-11-02 22:44:43 -07:00
parent 512c4831af
commit fdc0470860
3 changed files with 40 additions and 171 deletions

View File

@@ -17,55 +17,15 @@
 import argparse
 import logging
-import math
 import os
 from pathlib import Path
 from typing import Optional

-import fairseq
-import joblib
-import numpy as np
 import torch
 from lhotse import CutSet, SupervisionSegment
 from lhotse.utils import fastcopy
 from tqdm import tqdm

-# Torch's multithreaded behavior needs to be disabled or
-# it wastes a lot of CPU and slow things down.
-# Do this outside of main() in case it needs to take effect
-# even when we are not invoking the main (e.g. when spawning subprocesses).
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
-
-
-class ApplyKmeans(object):
-    def __init__(self, km_path):
-        self.km_model = joblib.load(km_path)
-        self.C_np = self.km_model.cluster_centers_.transpose()
-        self.Cnorm_np = (self.C_np**2).sum(0, keepdims=True)
-
-        self.C = torch.from_numpy(self.C_np)
-        self.Cnorm = torch.from_numpy(self.Cnorm_np)
-        if torch.cuda.is_available():
-            self.C = self.C.cuda()
-            self.Cnorm = self.Cnorm.cuda()
-
-    def __call__(self, x):
-        if isinstance(x, torch.Tensor):
-            dist = (
-                x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, self.C) + self.Cnorm
-            )
-            return dist.argmin(dim=1).cpu().numpy()
-        else:
-            dist = (
-                (x**2).sum(1, keepdims=True)
-                - 2 * np.matmul(x, self.C_np)
-                + self.Cnorm_np
-            )
-            return np.argmin(dist, axis=1)
-
-
 def get_args():
     parser = argparse.ArgumentParser()
@@ -82,12 +42,6 @@ def get_args():
         default="download/hubert_base_ls960.pt",
     )

-    parser.add_argument(
-        "--kmeans-model-path",
-        type=str,
-        default="download/hubert_base_ls960_L9_km500.bin",
-    )
-
     parser.add_argument(
         "--start",
         type=int,
@@ -102,90 +56,27 @@ def get_args():
         help="Stop processing pieces until this number (exclusive).",
     )

-    parser.add_argument(
-        "--window-duration",
-        type=float,
-        default=300.0,
-    )
-
-    parser.add_argument(
-        "--shift-duration",
-        type=float,
-        default=250.0,
-    )
-
     return parser.parse_args()


 @torch.no_grad()
 def extract_and_save_one_cuts(
-    raw_cuts_path,
+    manifests_path,
     cuts_path,
-    model,
-    apply_kmeans,
-    do_normalize,
-    window_duration,
-    shift_duration,
 ):
-    logging.info(f"Loading {raw_cuts_path}")
-    cut_set = CutSet.from_file(raw_cuts_path)
+    logging.info(f"Loading {manifests_path}")
+    cut_set = CutSet.from_file(manifests_path)

-    logging.info("Extracting kmeans")
+    logging.info("Extracting tokens")
     cuts = []

-    assert window_duration >= shift_duration
-    window_size = int(window_duration * 16000)
-    shift_size = int(shift_duration * 16000)
-    overlap_size = window_size - shift_size
-    out_overlap_size = get_out_length(overlap_size)
-
-    for cut in tqdm(cut_set):
-        assert cut.sampling_rate == 16000, f"Sampling rate: {cut.sampling_rate}"
-
-        audio = cut.load_audio()
-
-        T = audio.shape[1]
-        start = 0
-        kmeans = []
-        while start < T:
-            real_window_size = min(window_size, T - start)
-            audio_window = audio[:, start : start + real_window_size]
-
-            x = (
-                torch.from_numpy(audio_window)
-                .float()
-                .to(next(model.parameters()).device)
-            )
-
-            if do_normalize:
-                x = torch.nn.functional.layer_norm(x, x.shape)
-
-            feature, _ = model.extract_features(
-                source=x,
-                padding_mask=None,
-                mask=False,
-                output_layer=9,
-            )
-            feature = feature.squeeze(0)
-
-            current_kmeans = apply_kmeans(feature).tolist()
-
-            if start == 0:
-                kmeans.extend(current_kmeans)
-            else:
-                kmeans.extend(current_kmeans[out_overlap_size:])
-
-            if T - start <= window_size:
-                break
-
-            start += shift_size
-
-        kmeans = " ".join(map(str, kmeans))
-
-        cut_with_kmeans = fastcopy(
-            cut,
-            custom={"kmeans": kmeans},
-        )
-        cuts.append(cut_with_kmeans)
+        tokens = " ".join(map(str, tokens))
+
+        cut_with_tokens = fastcopy(
+            cut,
+            custom={"tokens": tokens},
+        )
+        cuts.append(cut_with_tokens)

     cuts = CutSet(cuts)
@@ -193,11 +84,11 @@ def extract_and_save_one_cuts(
     cuts.to_file(cuts_path)


-def extract_kmeans(args):
+def extract_speech_tokens(args):
     assert args.subset in ("small", "medium", "large"), f"{args.subset}"
     output_dir = (
-        f"data/kmeans/{args.subset}_split" if args.subset != "small" else "data/kmeans"
+        f"data/tokens/{args.subset}_split" if args.subset != "small" else "data/tokens"
     )
     output_dir = Path(output_dir)
     assert output_dir.exists(), f"{output_dir} does not exist!"
@@ -207,17 +98,7 @@ def extract_kmeans(args):
     device = torch.device("cuda", 0)
     logging.info(f"device: {device}")

-    prefix = "librilight"
-
-    apply_kmeans = ApplyKmeans(args.kmeans_model_path)
-    model, _, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
-        [args.model_path]
-    )
-    model = model[0].eval().to(device)
-    do_normalize = task.cfg.normalize
-
-    window_duration = args.window_duration
-    shift_duration = args.shift_duration
+    prefix = "libriheavy"

     if args.subset == "small":
         cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
@@ -225,16 +106,16 @@ def extract_kmeans(args):
             logging.info(f"{cuts_path} exists - skipping")
             return

-        raw_cuts_path = output_dir / f"{prefix}_cuts_{args.subset}_raw.jsonl.gz"
-        if not raw_cuts_path.is_file():
-            logging.info(f"{raw_cuts_path} does not exist - skipping it")
+        manifests_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
+        if not manifests_path.is_file():
+            logging.info(f"{manifests_path} does not exist - skipping it")
             return

         extract_and_save_one_cuts(
-            raw_cuts_path,
+            manifests_path,
             cuts_path,
             model,
-            apply_kmeans,
+            apply_tokens,
             do_normalize,
             window_duration,
             shift_duration,
@@ -254,36 +135,23 @@ def extract_kmeans(args):
                 logging.info(f"{cuts_path} exists - skipping")
                 continue

-            raw_cuts_path = (
-                output_dir / f"{prefix}_cuts_{args.subset}_raw.{idx}.jsonl.gz"
+            manifests_path = (
+                output_dir / f"{prefix}_cuts_{args.subset}.{idx}.jsonl.gz"
             )
-            if not raw_cuts_path.is_file():
-                logging.info(f"{raw_cuts_path} does not exist - skipping it")
+            if not manifests_path.is_file():
+                logging.info(f"{manifests_path} does not exist - skipping it")
                 continue

             extract_and_save_one_cuts(
-                raw_cuts_path,
+                manifests_path,
                 cuts_path,
-                model,
-                apply_kmeans,
-                do_normalize,
-                window_duration,
-                shift_duration,
             )


-def get_out_length(T):
-    conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
-    for i, (out_channels, kernel_size, stride) in enumerate(conv_layers):
-        T = math.floor((T - kernel_size) / stride) + 1
-
-    return max(0, T)
-
-
 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)
     args = get_args()
     logging.info(vars(args))
-    extract_kmeans(args)
+    extract_speech_tokens(args)
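
For reference, the rewritten script stores the token sequence as a space-separated string under the "tokens" key of each cut's custom dict. A minimal sketch of reading it back with lhotse, assuming a manifest named as in extract_speech_tokens() above (the exact path is illustrative, not taken from this commit):

    from lhotse import CutSet

    # Illustrative path following the data/tokens naming used above.
    cuts = CutSet.from_file("data/tokens/libriheavy_cuts_small.jsonl.gz")
    for cut in cuts:
        # "tokens" was written as a space-separated string, e.g. "12 407 33".
        token_ids = list(map(int, cut.custom["tokens"].split()))
        print(cut.id, len(token_ids))
        break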

View File

@@ -0,0 +1 @@
+../../ASR/local/norm_text.py

View File

@@ -81,8 +81,8 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   done
 fi

-if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 3: Prepare Libriheavy manifests"
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare Libriheavy manifests"
   mkdir -p $manifests_dir
   for subset in small medium large dev test_clean test_other; do
     if [ ! -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then
@@ -93,8 +93,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
 fi

 num_per_split=200000
-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "Stage 6: Split medium and large subsets."
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Split medium and large subsets."
   for subset in medium large; do
     log "Spliting subset : $subset"
     split_dir=$manifests_dir/libriheavy_${subset}_split
@@ -106,26 +106,26 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   done
 fi

-if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
-  log "Stage 10: Train BPE model for unnormalized text"
-  if [ ! -f data/punc_texts ]; then
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Train BPE model for normalized text"
+  if [ ! -f data/texts ]; then
     gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
-      | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts
+      | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
+      | ./local/norm_text.py > data/texts
   fi

   for vocab_size in ${vocab_sizes[@]}; do
-    lang_dir=data/lang_punc_bpe_${vocab_size}
+    lang_dir=data/lang_bpe_${vocab_size}
     mkdir -p $lang_dir
-    cp data/punc_texts $lang_dir/text
+    cp data/texts $lang_dir/text

     if [ ! -f $lang_dir/bpe.model ]; then
       ./local/train_bpe_model.py \
         --lang-dir $lang_dir \
-        --byte-fallback \
-        --vocab-size ${vocab_size} \
+        --vocab-size $vocab_size \
+        --byte-fallback \
+        --character-coverage 0.99 \
         --transcript $lang_dir/text
     fi
   done
 fi
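
With the stages renumbered as above, a typical run of the new prepare.sh would look roughly like this (assuming the script parses --stage/--stop-stage in the usual icefall fashion; actual defaults may differ):

    # Prepare manifests (stage 1), split medium/large (stage 2), train BPE (stage 3).
    ./prepare.sh --stage 1 --stop-stage 3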