add prepare.sh

yifanyeung 2024-11-02 22:44:43 -07:00
parent 512c4831af
commit fdc0470860
3 changed files with 40 additions and 171 deletions

View File

@@ -17,55 +17,15 @@
import argparse
import logging
import math
import os
from pathlib import Path
from typing import Optional
import fairseq
import joblib
import numpy as np
import torch
from lhotse import CutSet, SupervisionSegment
from lhotse.utils import fastcopy
from tqdm import tqdm
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
class ApplyKmeans(object):
def __init__(self, km_path):
self.km_model = joblib.load(km_path)
self.C_np = self.km_model.cluster_centers_.transpose()
self.Cnorm_np = (self.C_np**2).sum(0, keepdims=True)
self.C = torch.from_numpy(self.C_np)
self.Cnorm = torch.from_numpy(self.Cnorm_np)
if torch.cuda.is_available():
self.C = self.C.cuda()
self.Cnorm = self.Cnorm.cuda()
def __call__(self, x):
if isinstance(x, torch.Tensor):
dist = (
x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, self.C) + self.Cnorm
)
return dist.argmin(dim=1).cpu().numpy()
else:
dist = (
(x**2).sum(1, keepdims=True)
- 2 * np.matmul(x, self.C_np)
+ self.Cnorm_np
)
return np.argmin(dist, axis=1)
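
The deleted ApplyKmeans helper assigns every frame to its nearest k-means centroid via the expansion ||x - c||^2 = ||x||^2 - 2 x·c + ||c||^2, so only a matrix product against the centroid matrix is needed. A minimal NumPy-only sketch of the same idea (assign_clusters is an illustrative name, not part of the recipe):

import numpy as np

def assign_clusters(x, centers):
    """x: (T, D) frame features; centers: (K, D) k-means centroids."""
    C = centers.T                            # (D, K)
    Cnorm = (C ** 2).sum(0, keepdims=True)   # (1, K): ||c||^2 per centroid
    # squared distance to every centroid, computed without pairwise differences
    dist = (x ** 2).sum(1, keepdims=True) - 2 * (x @ C) + Cnorm
    return dist.argmin(axis=1)               # (T,) cluster ids

rng = np.random.default_rng(0)
print(assign_clusters(rng.normal(size=(5, 8)), rng.normal(size=(3, 8))))
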
def get_args():
parser = argparse.ArgumentParser()
@@ -82,12 +42,6 @@ def get_args():
default="download/hubert_base_ls960.pt",
)
parser.add_argument(
"--kmeans-model-path",
type=str,
default="download/hubert_base_ls960_L9_km500.bin",
)
parser.add_argument(
"--start",
type=int,
@@ -102,90 +56,27 @@ def get_args():
help="Stop processing pieces until this number (exclusive).",
)
parser.add_argument(
"--window-duration",
type=float,
default=300.0,
)
parser.add_argument(
"--shift-duration",
type=float,
default=250.0,
)
return parser.parse_args()
@torch.no_grad()
def extract_and_save_one_cuts(
raw_cuts_path,
manifests_path,
cuts_path,
model,
apply_kmeans,
do_normalize,
window_duration,
shift_duration,
):
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
logging.info(f"Loading {manifests_path}")
cut_set = CutSet.from_file(manifests_path)
logging.info("Extracting kmeans")
logging.info("Extracting tokens")
cuts = []
assert window_duration >= shift_duration
window_size = int(window_duration * 16000)
shift_size = int(shift_duration * 16000)
overlap_size = window_size - shift_size
out_overlap_size = get_out_length(overlap_size)
tokens = " ".join(map(str, tokens))
for cut in tqdm(cut_set):
assert cut.sampling_rate == 16000, f"Sampling rate: {cut.sampling_rate}"
audio = cut.load_audio()
T = audio.shape[1]
start = 0
kmeans = []
while start < T:
real_window_size = min(window_size, T - start)
audio_window = audio[:, start : start + real_window_size]
x = (
torch.from_numpy(audio_window)
.float()
.to(next(model.parameters()).device)
)
if do_normalize:
x = torch.nn.functional.layer_norm(x, x.shape)
feature, _ = model.extract_features(
source=x,
padding_mask=None,
mask=False,
output_layer=9,
)
feature = feature.squeeze(0)
current_kmeans = apply_kmeans(feature).tolist()
if start == 0:
kmeans.extend(current_kmeans)
else:
kmeans.extend(current_kmeans[out_overlap_size:])
if T - start <= window_size:
break
start += shift_size
kmeans = " ".join(map(str, kmeans))
cut_with_kmeans = fastcopy(
cut,
custom={"kmeans": kmeans},
)
cuts.append(cut_with_kmeans)
cut_with_tokens = fastcopy(
cut,
custom={"tokens": tokens},
)
cuts.append(cut_with_tokens)
cuts = CutSet(cuts)
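
The loop above runs the model on 300-second windows shifted by 250 seconds (so consecutive windows overlap by 50 seconds) and stitches the per-window labels by dropping the first out_overlap_size frames of every window after the first. A toy sketch of that stitching step, with illustrative names:

def stitch_windows(per_window_labels, out_overlap_size):
    # Keep window 0 whole; later windows drop the labels that duplicate
    # the overlap with the previous window.
    stitched = []
    for i, labels in enumerate(per_window_labels):
        stitched.extend(labels if i == 0 else labels[out_overlap_size:])
    return stitched

print(stitch_windows([[7, 7, 3, 4], [3, 4, 9, 9]], 2))  # -> [7, 7, 3, 4, 9, 9]
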
@@ -193,11 +84,11 @@ def extract_and_save_one_cuts(
cuts.to_file(cuts_path)
def extract_kmeans(args):
def extract_speech_tokens(args):
assert args.subset in ("small", "medium", "large"), f"{args.subset}"
output_dir = (
f"data/kmeans/{args.subset}_split" if args.subset != "small" else "data/kmeans"
f"data/tokens/{args.subset}_split" if args.subset != "small" else "data/tokens"
)
output_dir = Path(output_dir)
assert output_dir.exists(), f"{output_dir} does not exist!"
@@ -207,17 +98,7 @@ def extract_kmeans(args):
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
prefix = "librilight"
apply_kmeans = ApplyKmeans(args.kmeans_model_path)
model, _, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[args.model_path]
)
model = model[0].eval().to(device)
do_normalize = task.cfg.normalize
window_duration = args.window_duration
shift_duration = args.shift_duration
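
For reference, the lines removed here load a fairseq HuBERT checkpoint whose layer-9 features feed the k-means quantizer. Condensed into one sketch (assumes fairseq is installed and the default checkpoint from get_args has been downloaded):

import fairseq
import torch

models, _, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
    ["download/hubert_base_ls960.pt"]
)
model = models[0].eval()
audio = torch.zeros(1, 16000)  # one second of 16 kHz audio, illustration only
if task.cfg.normalize:
    audio = torch.nn.functional.layer_norm(audio, audio.shape)
with torch.no_grad():
    feature, _ = model.extract_features(
        source=audio, padding_mask=None, mask=False, output_layer=9
    )
print(feature.shape)  # roughly (1, 49, 768) for the base model
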
prefix = "libriheavy"
if args.subset == "small":
cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
@@ -225,16 +106,16 @@ def extract_kmeans(args):
logging.info(f"{cuts_path} exists - skipping")
return
raw_cuts_path = output_dir / f"{prefix}_cuts_{args.subset}_raw.jsonl.gz"
if not raw_cuts_path.is_file():
logging.info(f"{raw_cuts_path} does not exist - skipping it")
manifests_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
if not manifests_path.is_file():
logging.info(f"{manifests_path} does not exist - skipping it")
return
extract_and_save_one_cuts(
raw_cuts_path,
manifests_path,
cuts_path,
model,
apply_kmeans,
apply_tokens,
do_normalize,
window_duration,
shift_duration,
@@ -254,36 +135,23 @@ def extract_kmeans(args):
logging.info(f"{cuts_path} exists - skipping")
continue
raw_cuts_path = (
output_dir / f"{prefix}_cuts_{args.subset}_raw.{idx}.jsonl.gz"
manifests_path = (
output_dir / f"{prefix}_cuts_{args.subset}.{idx}.jsonl.gz"
)
if not raw_cuts_path.is_file():
logging.info(f"{raw_cuts_path} does not exist - skipping it")
if not manifests_path.is_file():
logging.info(f"{manifests_path} does not exist - skipping it")
continue
extract_and_save_one_cuts(
raw_cuts_path,
manifests_path,
cuts_path,
model,
apply_kmeans,
do_normalize,
window_duration,
shift_duration,
)
def get_out_length(T):
conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
for i, (out_channels, kernel_size, stride) in enumerate(conv_layers):
T = math.floor((T - kernel_size) / stride) + 1
return max(0, T)
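
get_out_length replays the conv front-end of the wav2vec 2.0 / HuBERT feature extractor (total stride 5*2*2*2*2*2*2 = 320, i.e. one frame per 20 ms at 16 kHz) to convert a length in samples into a length in output frames. A quick check for the 50-second overlap implied by the defaults (window 300 s, shift 250 s):

import math

def get_out_length(T):
    conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
    for _, kernel_size, stride in conv_layers:
        T = math.floor((T - kernel_size) / stride) + 1
    return max(0, T)

print(get_out_length(50 * 16000))  # 2499 frames, roughly one every 20 ms
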
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
logging.info(vars(args))
extract_kmeans(args)
extract_speech_tokens(args)

View File

@@ -0,0 +1 @@
../../ASR/local/norm_text.py
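
The new local/norm_text.py is only a symlink into the ASR recipe, so its contents are not part of this diff; prepare.sh below uses it as a stdin-to-stdout filter. Purely as a hypothetical illustration of that calling convention (not the actual script), a normalizer of this kind could look like:

#!/usr/bin/env python3
# Hypothetical sketch only -- the real ../../ASR/local/norm_text.py is not shown here.
import re
import sys

for line in sys.stdin:
    # e.g. upper-case and keep only letters, digits, apostrophes and spaces
    text = re.sub(r"[^A-Z0-9' ]", " ", line.upper())
    print(" ".join(text.split()))
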

View File

@@ -81,8 +81,8 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
done
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare Libriheavy manifests"
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare Libriheavy manifests"
mkdir -p $manifests_dir
for subset in small medium large dev test_clean test_other; do
if [ ! -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then
@@ -93,8 +93,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
fi
num_per_split=200000
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Split medium and large subsets."
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Split medium and large subsets."
for subset in medium large; do
log "Spliting subset : $subset"
split_dir=$manifests_dir/libriheavy_${subset}_split
@@ -106,26 +106,26 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
done
fi
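
The command that actually splits the manifests sits outside this hunk; in lhotse-based recipes this is typically done with CutSet.split_lazy, which writes numbered pieces of chunk_size cuts each. A sketch under that assumption (paths illustrative):

from lhotse import CutSet

cuts = CutSet.from_file("data/tokens/libriheavy_cuts_medium.jsonl.gz")
cuts.split_lazy(
    output_dir="data/tokens/libriheavy_medium_split",
    chunk_size=200000,  # matches num_per_split above
)
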
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 10: Train BPE model for unnormalized text"
if [ ! -f data/punc_texts ]; then
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Train BPE model for normalized text"
if [ ! -f data/texts ]; then
gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
| jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts
| jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
| ./local/norm_text.py > data/texts
fi
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_punc_bpe_${vocab_size}
lang_dir=data/lang_bpe_${vocab_size}
mkdir -p $lang_dir
cp data/punc_texts $lang_dir/text
cp data/texts $lang_dir/text
if [ ! -f $lang_dir/bpe.model ]; then
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--byte-fallback \
--vocab-size ${vocab_size} \
--byte-fallback \
--character-coverage 0.99 \
--vocab-size $vocab_size \
--transcript $lang_dir/text
fi
done
fi
fi
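
train_bpe_model.py is not part of this commit; the flags above ask for byte-fallback BPE with 0.99 character coverage over the normalized transcripts, which corresponds roughly to the following direct sentencepiece call (paths and vocab size illustrative):

import sentencepiece as spm

spm.SentencePieceTrainer.train(
    input="data/lang_bpe_500/text",      # the copied data/texts transcript
    model_prefix="data/lang_bpe_500/bpe",
    model_type="bpe",
    vocab_size=500,
    byte_fallback=True,
    character_coverage=0.99,
)
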