Mirror of https://github.com/k2-fsa/icefall.git

commit fdc0470860 (parent 512c4831af)
add prepare.sh
@@ -17,55 +17,15 @@
 import argparse
 import logging
-import math
 import os
 from pathlib import Path
 from typing import Optional
 
-import fairseq
-import joblib
-import numpy as np
 import torch
 from lhotse import CutSet, SupervisionSegment
 from lhotse.utils import fastcopy
 from tqdm import tqdm
 
-# Torch's multithreaded behavior needs to be disabled or
-# it wastes a lot of CPU and slow things down.
-# Do this outside of main() in case it needs to take effect
-# even when we are not invoking the main (e.g. when spawning subprocesses).
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
-
-
-class ApplyKmeans(object):
-    def __init__(self, km_path):
-        self.km_model = joblib.load(km_path)
-        self.C_np = self.km_model.cluster_centers_.transpose()
-        self.Cnorm_np = (self.C_np**2).sum(0, keepdims=True)
-
-        self.C = torch.from_numpy(self.C_np)
-        self.Cnorm = torch.from_numpy(self.Cnorm_np)
-        if torch.cuda.is_available():
-            self.C = self.C.cuda()
-            self.Cnorm = self.Cnorm.cuda()
-
-    def __call__(self, x):
-        if isinstance(x, torch.Tensor):
-            dist = (
-                x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, self.C) + self.Cnorm
-            )
-            return dist.argmin(dim=1).cpu().numpy()
-        else:
-            dist = (
-                (x**2).sum(1, keepdims=True)
-                - 2 * np.matmul(x, self.C_np)
-                + self.Cnorm_np
-            )
-            return np.argmin(dist, axis=1)
-
-
 def get_args():
     parser = argparse.ArgumentParser()
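The removed ApplyKmeans helper quantized frame-level HuBERT features into discrete units by nearest-centroid search over a pretrained k-means codebook. A minimal usage sketch of that idea, assuming a scikit-learn-style k-means model saved with joblib; the path follows the default removed in this commit and the feature tensor is random data purely for illustration:

    import joblib
    import torch

    # Illustrative only: download/hubert_base_ls960_L9_km500.bin holds 500 centroids.
    km_model = joblib.load("download/hubert_base_ls960_L9_km500.bin")
    C = torch.from_numpy(km_model.cluster_centers_.transpose())  # (dim, n_clusters)
    Cnorm = (C**2).sum(0, keepdim=True)                          # (1, n_clusters)

    features = torch.randn(100, C.shape[0])  # 100 frames of layer-9 HuBERT features
    # Squared Euclidean distance ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2;
    # argmin over centroids gives one discrete unit per frame.
    dist = features.pow(2).sum(1, keepdim=True) - 2 * features @ C + Cnorm
    units = dist.argmin(dim=1)  # shape (100,), values in [0, 500)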
@@ -82,12 +42,6 @@ def get_args():
         default="download/hubert_base_ls960.pt",
     )
-
-    parser.add_argument(
-        "--kmeans-model-path",
-        type=str,
-        default="download/hubert_base_ls960_L9_km500.bin",
-    )
 
     parser.add_argument(
         "--start",
         type=int,
@@ -102,90 +56,27 @@ def get_args():
         help="Stop processing pieces until this number (exclusive).",
     )
 
-    parser.add_argument(
-        "--window-duration",
-        type=float,
-        default=300.0,
-    )
-
-    parser.add_argument(
-        "--shift-duration",
-        type=float,
-        default=250.0,
-    )
-
     return parser.parse_args()
 
 
 @torch.no_grad()
 def extract_and_save_one_cuts(
-    raw_cuts_path,
+    manifests_path,
     cuts_path,
-    model,
-    apply_kmeans,
-    do_normalize,
-    window_duration,
-    shift_duration,
 ):
-    logging.info(f"Loading {raw_cuts_path}")
-    cut_set = CutSet.from_file(raw_cuts_path)
+    logging.info(f"Loading {manifests_path}")
+    cut_set = CutSet.from_file(manifests_path)
 
-    logging.info("Extracting kmeans")
+    logging.info("Extracting tokens")
     cuts = []
 
-    assert window_duration >= shift_duration
-    window_size = int(window_duration * 16000)
-    shift_size = int(shift_duration * 16000)
-    overlap_size = window_size - shift_size
-    out_overlap_size = get_out_length(overlap_size)
-
-    for cut in tqdm(cut_set):
-        assert cut.sampling_rate == 16000, f"Sampling rate: {cut.sampling_rate}"
-
-        audio = cut.load_audio()
-
-        T = audio.shape[1]
-        start = 0
-        kmeans = []
-        while start < T:
-            real_window_size = min(window_size, T - start)
-            audio_window = audio[:, start : start + real_window_size]
-
-            x = (
-                torch.from_numpy(audio_window)
-                .float()
-                .to(next(model.parameters()).device)
-            )
-            if do_normalize:
-                x = torch.nn.functional.layer_norm(x, x.shape)
-
-            feature, _ = model.extract_features(
-                source=x,
-                padding_mask=None,
-                mask=False,
-                output_layer=9,
-            )
-            feature = feature.squeeze(0)
-
-            current_kmeans = apply_kmeans(feature).tolist()
-
-            if start == 0:
-                kmeans.extend(current_kmeans)
-            else:
-                kmeans.extend(current_kmeans[out_overlap_size:])
-
-            if T - start <= window_size:
-                break
-
-            start += shift_size
-
-        kmeans = " ".join(map(str, kmeans))
-
-        cut_with_kmeans = fastcopy(
-            cut,
-            custom={"kmeans": kmeans},
-        )
-        cuts.append(cut_with_kmeans)
+        tokens = " ".join(map(str, tokens))
+        cut_with_tokens = fastcopy(
+            cut,
+            custom={"tokens": tokens},
+        )
+        cuts.append(cut_with_tokens)
 
     cuts = CutSet(cuts)
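After this change, extract_and_save_one_cuts no longer runs the windowed HuBERT + k-means pipeline itself; it only attaches a token string to each cut via fastcopy. A minimal sketch of that pattern with lhotse, assuming the per-cut token ids are already available (the custom field used as the source here is an assumption; the diff does not show where the tokens come from):

    from lhotse import CutSet
    from lhotse.utils import fastcopy
    from tqdm import tqdm

    def attach_tokens(manifests_path, cuts_path):
        cut_set = CutSet.from_file(manifests_path)
        cuts = []
        for cut in tqdm(cut_set):
            # Assumed source of the per-frame token ids for this cut.
            token_ids = cut.custom.get("tokens", [])
            tokens = " ".join(map(str, token_ids))  # store as a space-separated string
            cuts.append(fastcopy(cut, custom={"tokens": tokens}))
        CutSet.from_cuts(cuts).to_file(cuts_path)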
@@ -193,11 +84,11 @@ def extract_and_save_one_cuts(
     cuts.to_file(cuts_path)
 
 
-def extract_kmeans(args):
+def extract_speech_tokens(args):
     assert args.subset in ("small", "medium", "large"), f"{args.subset}"
 
     output_dir = (
-        f"data/kmeans/{args.subset}_split" if args.subset != "small" else "data/kmeans"
+        f"data/tokens/{args.subset}_split" if args.subset != "small" else "data/tokens"
     )
     output_dir = Path(output_dir)
     assert output_dir.exists(), f"{output_dir} does not exist!"
@@ -207,17 +98,7 @@ def extract_kmeans(args):
     device = torch.device("cuda", 0)
     logging.info(f"device: {device}")
 
-    prefix = "librilight"
+    prefix = "libriheavy"
 
-    apply_kmeans = ApplyKmeans(args.kmeans_model_path)
-    model, _, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
-        [args.model_path]
-    )
-    model = model[0].eval().to(device)
-    do_normalize = task.cfg.normalize
-
-    window_duration = args.window_duration
-    shift_duration = args.shift_duration
-
     if args.subset == "small":
         cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
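The lines removed here are the usual fairseq way of loading a HuBERT checkpoint together with its task config; combined with the windowed loop removed earlier, they produced layer-9 features for quantization. A condensed sketch of that removed flow (the path and the normalization flag come from the deleted code; the waveform and shapes are illustrative only):

    import fairseq
    import torch

    models, _, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        ["download/hubert_base_ls960.pt"]
    )
    model = models[0].eval().to("cuda" if torch.cuda.is_available() else "cpu")

    wav = torch.randn(1, 16000 * 10)  # 10 s of 16 kHz audio (illustrative)
    wav = wav.to(next(model.parameters()).device)
    if task.cfg.normalize:  # False for hubert_base, True for the large models
        wav = torch.nn.functional.layer_norm(wav, wav.shape)

    with torch.no_grad():
        features, _ = model.extract_features(
            source=wav, padding_mask=None, mask=False, output_layer=9
        )
    features = features.squeeze(0)  # (num_frames, 768) for hubert_base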
@@ -225,16 +106,16 @@ def extract_kmeans(args):
             logging.info(f"{cuts_path} exists - skipping")
             return
 
-        raw_cuts_path = output_dir / f"{prefix}_cuts_{args.subset}_raw.jsonl.gz"
-        if not raw_cuts_path.is_file():
-            logging.info(f"{raw_cuts_path} does not exist - skipping it")
+        manifests_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
+        if not manifests_path.is_file():
+            logging.info(f"{manifests_path} does not exist - skipping it")
             return
 
         extract_and_save_one_cuts(
-            raw_cuts_path,
+            manifests_path,
             cuts_path,
             model,
-            apply_kmeans,
+            apply_tokens,
             do_normalize,
             window_duration,
             shift_duration,
@@ -254,36 +135,23 @@ def extract_kmeans(args):
                 logging.info(f"{cuts_path} exists - skipping")
                 continue
 
-            raw_cuts_path = (
-                output_dir / f"{prefix}_cuts_{args.subset}_raw.{idx}.jsonl.gz"
+            manifests_path = (
+                output_dir / f"{prefix}_cuts_{args.subset}.{idx}.jsonl.gz"
             )
-            if not raw_cuts_path.is_file():
-                logging.info(f"{raw_cuts_path} does not exist - skipping it")
+            if not manifests_path.is_file():
+                logging.info(f"{manifests_path} does not exist - skipping it")
                 continue
 
             extract_and_save_one_cuts(
-                raw_cuts_path,
+                manifests_path,
                 cuts_path,
-                model,
-                apply_kmeans,
-                do_normalize,
-                window_duration,
-                shift_duration,
             )
 
 
-def get_out_length(T):
-    conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
-    for i, (out_channels, kernel_size, stride) in enumerate(conv_layers):
-        T = math.floor((T - kernel_size) / stride) + 1
-
-    return max(0, T)
-
-
 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     args = get_args()
     logging.info(vars(args))
-    extract_kmeans(args)
+    extract_speech_tokens(args)
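The deleted get_out_length helper mirrored HuBERT's convolutional front end: strides of 5, 2, 2, 2, 2, 2, 2 downsample 16 kHz audio by a factor of 320, i.e. one feature frame per 20 ms. A short check of that arithmetic, sketched from the removed helper:

    import math

    def get_out_length(T):
        # (out_channels, kernel_size, stride) per conv layer, as in the removed code.
        conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
        for _, kernel_size, stride in conv_layers:
            T = math.floor((T - kernel_size) / stride) + 1
        return max(0, T)

    # 5 * 2**6 == 320 samples per frame, so 300 s of 16 kHz audio
    # (4_800_000 samples) yields roughly 15_000 frames.
    print(get_out_length(300 * 16000))  # 14999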
egs/libriheavy/TTS/local/norm_text.py (new symbolic link)
@@ -0,0 +1 @@
+../../ASR/local/norm_text.py
prepare.sh
@@ -81,8 +81,8 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   done
 fi
 
-if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 3: Prepare Libriheavy manifests"
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare Libriheavy manifests"
   mkdir -p $manifests_dir
   for subset in small medium large dev test_clean test_other; do
     if [ ! -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then
@@ -93,8 +93,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
 fi
 
 num_per_split=200000
-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "Stage 6: Split medium and large subsets."
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Split medium and large subsets."
   for subset in medium large; do
     log "Spliting subset : $subset"
     split_dir=$manifests_dir/libriheavy_${subset}_split
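Stage 2 splits the medium and large manifests into pieces of num_per_split=200000 cuts so that token extraction can be parallelized with the --start/--stop arguments above. The split command itself lies outside this hunk; a rough Python equivalent with lhotse, assuming the directory layout shown (split_lazy writes numbered .jsonl.gz chunks into the output directory):

    from lhotse import CutSet

    subset = "medium"  # or "large"
    manifests_dir = "data/manifests"  # assumed value of $manifests_dir
    cuts = CutSet.from_file(f"{manifests_dir}/libriheavy_cuts_{subset}.jsonl.gz")
    cuts.split_lazy(
        output_dir=f"{manifests_dir}/libriheavy_{subset}_split",
        chunk_size=200000,  # num_per_split in prepare.sh
    )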
@@ -106,25 +106,25 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   done
 fi
 
-if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
-  log "Stage 10: Train BPE model for unnormalized text"
-  if [ ! -f data/punc_texts ]; then
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Train BPE model for normalized text"
+  if [ ! -f data/texts ]; then
     gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
-      | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts
+      | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
+      | ./local/norm_text.py > data/texts
   fi
 
   for vocab_size in ${vocab_sizes[@]}; do
-    lang_dir=data/lang_punc_bpe_${vocab_size}
+    lang_dir=data/lang_bpe_${vocab_size}
    mkdir -p $lang_dir
 
-    cp data/punc_texts $lang_dir/text
+    cp data/texts $lang_dir/text
 
     if [ ! -f $lang_dir/bpe.model ]; then
       ./local/train_bpe_model.py \
         --lang-dir $lang_dir \
-        --byte-fallback \
-        --vocab-size ${vocab_size} \
-        --byte-fallback \
-        --character-coverage 0.99 \
+        --vocab-size $vocab_size \
         --transcript $lang_dir/text
     fi
   done
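Stage 3 now pipes the supervision texts through ./local/norm_text.py before training the BPE model, and the lang dir is renamed from lang_punc_bpe_* to lang_bpe_* accordingly. For the text-dumping half of that pipeline, a rough Python equivalent of the gunzip/jq/sed part (normalization itself still belongs to norm_text.py and is not reproduced here; the manifest path and output file name are assumptions for illustration):

    import gzip
    import json

    # Assumed value of $manifests_dir in prepare.sh; output name is hypothetical.
    manifest = "data/manifests/libriheavy_cuts_medium.jsonl.gz"
    with gzip.open(manifest, "rt") as fin, open("data/texts_raw", "w") as fout:
        for line in fin:
            cut = json.loads(line)
            for sup in cut.get("supervisions", []):
                # Equivalent of jq '.supervisions[].text' with the quotes stripped.
                fout.write(sup["text"] + "\n")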