update prepare.sh

Your Name 2024-10-23 00:10:24 -07:00
parent c9207356af
commit 84f8adff32
2 changed files with 14 additions and 33 deletions

View File

@@ -26,7 +26,6 @@ import numpy as np
import torch
from lhotse import CutSet, SupervisionSegment
from lhotse.utils import fastcopy
from silero_vad import get_speech_timestamps, load_silero_vad
from tqdm import tqdm
# Torch's multithreaded behavior needs to be disabled or
@@ -82,7 +81,7 @@ def get_args():
parser.add_argument(
"--kmeans-model-path",
type=str,
default="download/hubert_base_ls960_L9_km500.model",
default="download/hubert_base_ls960_L9_km500.bin",
)
parser.add_argument(
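The new default points at the k-means quantizer released alongside fairseq's HuBERT base checkpoint (hubert_base_ls960_L9_km500.bin), which prepare.sh now downloads. The ApplyKmeans class used later in this script is outside the hunks shown here; below is a minimal sketch of how such a joblib-dumped sklearn k-means model is typically applied, following the fairseq simple_kmeans convention (the class body is an assumption, not code from this repo):

```python
import joblib
import numpy as np
import torch


class ApplyKmeans:
    """Map frame-level features to the index of their nearest k-means centroid."""

    def __init__(self, km_path: str):
        # The .bin file is a joblib-dumped sklearn (MiniBatch)KMeans model.
        km_model = joblib.load(km_path)
        self.C = km_model.cluster_centers_.transpose()  # (feat_dim, n_clusters)
        self.Cnorm = (self.C**2).sum(0, keepdims=True)  # (1, n_clusters)

    def __call__(self, x) -> np.ndarray:
        # x: (num_frames, feat_dim), torch.Tensor or np.ndarray.
        if isinstance(x, torch.Tensor):
            x = x.cpu().numpy()
        # Squared Euclidean distance: ||x||^2 - 2 x.C + ||C||^2, argmin over clusters.
        dist = (x**2).sum(1, keepdims=True) - 2 * np.matmul(x, self.C) + self.Cnorm
        return dist.argmin(axis=1)
```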
@@ -103,7 +102,7 @@ def get_args():
def extract_and_save_one_cuts(
raw_cuts_path, cuts_path, model, vad_model, apply_kmeans, do_normalize, device
raw_cuts_path, cuts_path, model, apply_kmeans, do_normalize, device
):
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
@@ -111,20 +110,11 @@ extract_and_save_one_cuts(
logging.info("Extracting kmeans")
cuts = []
for cut in tqdm(cut_set):
assert cut.sampling_rate == 16000, f"{cut.sampling_rate}"
assert cut.sampling_rate == 16000, f"Sampling rate: {cut.sampling_rate}"
audio = cut.load_audio()
if audio.shape[-1] > 64 * 16000:
timestamps = get_speech_timestamps(audio, vad_model)
offsets = [i["start"] for i in timestamps]
audios = [audio[:, i["start"] : i["end"]] for i in timestamps]
logging.info(f"Trim audio {cut.id} into {len(audios)} segments")
else:
offsets = [0]
audios = [audio]
seq = 0
for audio, offset in zip(audios, offsets):
offsets = 0
if True:
x = torch.from_numpy(audio).float().to(device)
with torch.no_grad():
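For reference, the branch deleted above used silero-vad to cut recordings longer than 64 seconds into speech segments before feature extraction; the new code feeds each cut to the model whole. A standalone sketch of the removed segmentation pattern, assuming the pip-installed silero-vad package (the helper name and threshold are illustrative):

```python
import torch
from silero_vad import get_speech_timestamps, load_silero_vad


def split_long_audio(audio: torch.Tensor, vad_model, max_samples: int = 64 * 16000):
    """Return (offsets, segments) for a (1, num_samples) waveform at 16 kHz."""
    if audio.shape[-1] <= max_samples:
        return [0], [audio]
    # Each timestamp is a {"start": ..., "end": ...} dict in samples.
    timestamps = get_speech_timestamps(audio.squeeze(0), vad_model)
    offsets = [t["start"] for t in timestamps]
    segments = [audio[:, t["start"] : t["end"]] for t in timestamps]
    return offsets, segments


# Usage: offsets, segments = split_long_audio(waveform, load_silero_vad())
```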
@@ -141,24 +131,12 @@ extract_and_save_one_cuts(
kmeans = " ".join(map(str, apply_kmeans(feature).tolist()))
supervision_segment = fastcopy(
cut.supervisions[0],
id=f"{cut.id}-{seq}",
start=0.0,
duration=audio.shape[-1] / 16000,
)
cut_with_kmeans = fastcopy(
cut,
id=f"{cut.id}-{seq}",
start=cut.start + offset / 16000,
duration=audio.shape[-1] / 16000,
supervisions=[supervision_segment],
custom={"kmeans": kmeans},
)
cuts.append(cut_with_kmeans)
seq += 1
cuts = CutSet(cuts)
logging.info(f"Saving to {cuts_path}")
@@ -181,7 +159,6 @@ def extract_kmeans(args):
prefix = "librilight"
vad_model = load_silero_vad()
apply_kmeans = ApplyKmeans(args.kmeans_model_path)
model, _, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[args.model_path]
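The features quantized by apply_kmeans come from this HuBERT base checkpoint; the "L9" in the k-means filename refers to transformer layer 9. A sketch of that feature-extraction step, assuming the standard fairseq HubertModel.extract_features API and that do_normalize mirrors task.cfg.normalize:

```python
import torch
import torch.nn.functional as F
from fairseq import checkpoint_utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
models, _, task = checkpoint_utils.load_model_ensemble_and_task(
    ["download/hubert_base_ls960.pt"]
)
model = models[0].eval().to(device)
do_normalize = task.cfg.normalize  # assumption: waveform normalization flag


def hubert_layer9_features(waveform: torch.Tensor) -> torch.Tensor:
    # waveform: (1, num_samples) float tensor at 16 kHz.
    x = waveform.to(device)
    if do_normalize:
        x = F.layer_norm(x, x.shape)
    with torch.no_grad():
        feature, _ = model.extract_features(
            source=x,
            padding_mask=None,
            mask=False,
            output_layer=9,
        )
    return feature.squeeze(0)  # (num_frames, 768), ready for apply_kmeans
```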
@@ -204,7 +181,6 @@ def extract_kmeans(args):
raw_cuts_path,
cuts_path,
model,
vad_model,
apply_kmeans,
do_normalize,
device,
@@ -235,7 +211,6 @@ def extract_kmeans(args):
raw_cuts_path,
cuts_path,
model,
vad_model,
apply_kmeans,
do_normalize,
device,

View File

@@ -5,7 +5,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=15
nj=32
# run step 0 to step 4 by default
stage=0
stop_stage=4
@@ -58,13 +58,13 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
mkdir -p data/kmeans
if [ ! -f data/kmeans/.preprocess_complete ]; then
python3 ./local/preprocess_librilight.py
touch data/fbank/.preprocess_complete
touch data/kmeans/.preprocess_complete
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Split medium and large subset into pieces"
num_per_split=200000
num_per_split=2500
split_dir=data/kmeans/medium_split
if [ ! -f $split_dir/.split_completed ]; then
lhotse split-lazy ./data/kmeans/librilight_cuts_medium_raw.jsonl.gz $split_dir $num_per_split
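Dropping num_per_split from 200000 to 2500 produces many more, much smaller pieces, so the per-piece extraction in stage 4 can be parallelized and resumed at a finer grain. Roughly what the lhotse split-lazy command does through the Python API (a sketch assuming lhotse's load_manifest_lazy and CutSet.split_lazy helpers):

```python
from lhotse import load_manifest_lazy

# Stream the large raw manifest and write ~2500-cut chunks that downstream
# extraction jobs can process and checkpoint independently.
cuts = load_manifest_lazy("data/kmeans/librilight_cuts_medium_raw.jsonl.gz")
cuts.split_lazy("data/kmeans/medium_split", chunk_size=2500)
```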
@@ -79,6 +79,12 @@ fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Extract SSL target for librilight"
if [ ! -e download/hubert_base_ls960.pt ]; then
wget https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt -P download
fi
if [ ! -e download/hubert_base_ls960_L9_km500.bin ]; then
wget https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960_L9_km500.bin -P download
fi
if [ ! -e data/kmeans/.extract_small.done ]; then
./local/extract_kmeans_from_hubert_base.py --subset small
touch data/kmeans/.extract_small.done