mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
update prepare.sh
This commit is contained in:
parent
c9207356af
commit
84f8adff32
@ -26,7 +26,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, SupervisionSegment
|
from lhotse import CutSet, SupervisionSegment
|
||||||
from lhotse.utils import fastcopy
|
from lhotse.utils import fastcopy
|
||||||
from silero_vad import get_speech_timestamps, load_silero_vad
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
@ -82,7 +81,7 @@ def get_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--kmeans-model-path",
|
"--kmeans-model-path",
|
||||||
type=str,
|
type=str,
|
||||||
default="download/hubert_base_ls960_L9_km500.model",
|
default="download/hubert_base_ls960_L9_km500.bin",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -103,7 +102,7 @@ def get_args():
|
|||||||
|
|
||||||
|
|
||||||
def extract_and_save_one_cuts(
|
def extract_and_save_one_cuts(
|
||||||
raw_cuts_path, cuts_path, model, vad_model, apply_kmeans, do_normalize, device
|
raw_cuts_path, cuts_path, model, apply_kmeans, do_normalize, device
|
||||||
):
|
):
|
||||||
logging.info(f"Loading {raw_cuts_path}")
|
logging.info(f"Loading {raw_cuts_path}")
|
||||||
cut_set = CutSet.from_file(raw_cuts_path)
|
cut_set = CutSet.from_file(raw_cuts_path)
|
||||||
@ -111,20 +110,11 @@ def extract_and_save_one_cuts(
|
|||||||
logging.info("Extracting kmeans")
|
logging.info("Extracting kmeans")
|
||||||
cuts = []
|
cuts = []
|
||||||
for cut in tqdm(cut_set):
|
for cut in tqdm(cut_set):
|
||||||
assert cut.sampling_rate == 16000, f"{cut.sampling_rate}"
|
assert cut.sampling_rate == 16000, f"Sampling rate: {cut.sampling_rate}"
|
||||||
audio = cut.load_audio()
|
audio = cut.load_audio()
|
||||||
|
|
||||||
if audio.shape[-1] > 64 * 16000:
|
offsets = 0
|
||||||
timestamps = get_speech_timestamps(audio, vad_model)
|
if True:
|
||||||
offsets = [i["start"] for i in timestamps]
|
|
||||||
audios = [audio[:, i["start"] : i["end"]] for i in timestamps]
|
|
||||||
logging.info(f"Trim audio {cut.id} into {len(audios)} segments")
|
|
||||||
else:
|
|
||||||
offsets = [0]
|
|
||||||
audios = [audio]
|
|
||||||
|
|
||||||
seq = 0
|
|
||||||
for audio, offset in zip(audios, offsets):
|
|
||||||
x = torch.from_numpy(audio).float().to(device)
|
x = torch.from_numpy(audio).float().to(device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -141,24 +131,12 @@ def extract_and_save_one_cuts(
|
|||||||
|
|
||||||
kmeans = " ".join(map(str, apply_kmeans(feature).tolist()))
|
kmeans = " ".join(map(str, apply_kmeans(feature).tolist()))
|
||||||
|
|
||||||
supervision_segment = fastcopy(
|
|
||||||
cut.supervisions[0],
|
|
||||||
id=f"{cut.id}-{seq}",
|
|
||||||
start=0.0,
|
|
||||||
duration=audio.shape[-1] / 16000,
|
|
||||||
)
|
|
||||||
cut_with_kmeans = fastcopy(
|
cut_with_kmeans = fastcopy(
|
||||||
cut,
|
cut,
|
||||||
id=f"{cut.id}-{seq}",
|
|
||||||
start=cut.start + offset / 16000,
|
|
||||||
duration=audio.shape[-1] / 16000,
|
|
||||||
supervisions=[supervision_segment],
|
|
||||||
custom={"kmeans": kmeans},
|
custom={"kmeans": kmeans},
|
||||||
)
|
)
|
||||||
cuts.append(cut_with_kmeans)
|
cuts.append(cut_with_kmeans)
|
||||||
|
|
||||||
seq += 1
|
|
||||||
|
|
||||||
cuts = CutSet(cuts)
|
cuts = CutSet(cuts)
|
||||||
|
|
||||||
logging.info(f"Saving to {cuts_path}")
|
logging.info(f"Saving to {cuts_path}")
|
||||||
@ -181,7 +159,6 @@ def extract_kmeans(args):
|
|||||||
|
|
||||||
prefix = "librilight"
|
prefix = "librilight"
|
||||||
|
|
||||||
vad_model = load_silero_vad()
|
|
||||||
apply_kmeans = ApplyKmeans(args.kmeans_model_path)
|
apply_kmeans = ApplyKmeans(args.kmeans_model_path)
|
||||||
model, _, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
model, _, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
||||||
[args.model_path]
|
[args.model_path]
|
||||||
@ -204,7 +181,6 @@ def extract_kmeans(args):
|
|||||||
raw_cuts_path,
|
raw_cuts_path,
|
||||||
cuts_path,
|
cuts_path,
|
||||||
model,
|
model,
|
||||||
vad_model,
|
|
||||||
apply_kmeans,
|
apply_kmeans,
|
||||||
do_normalize,
|
do_normalize,
|
||||||
device,
|
device,
|
||||||
@ -235,7 +211,6 @@ def extract_kmeans(args):
|
|||||||
raw_cuts_path,
|
raw_cuts_path,
|
||||||
cuts_path,
|
cuts_path,
|
||||||
model,
|
model,
|
||||||
vad_model,
|
|
||||||
apply_kmeans,
|
apply_kmeans,
|
||||||
do_normalize,
|
do_normalize,
|
||||||
device,
|
device,
|
||||||
|
@ -5,7 +5,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
|||||||
|
|
||||||
set -eou pipefail
|
set -eou pipefail
|
||||||
|
|
||||||
nj=15
|
nj=32
|
||||||
# run step 0 to step 4 by default
|
# run step 0 to step 4 by default
|
||||||
stage=0
|
stage=0
|
||||||
stop_stage=4
|
stop_stage=4
|
||||||
@ -58,13 +58,13 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
|||||||
mkdir -p data/kmeans
|
mkdir -p data/kmeans
|
||||||
if [ ! -f data/kmeans/.preprocess_complete ]; then
|
if [ ! -f data/kmeans/.preprocess_complete ]; then
|
||||||
python3 ./local/preprocess_librilight.py
|
python3 ./local/preprocess_librilight.py
|
||||||
touch data/fbank/.preprocess_complete
|
touch data/kmeans/.preprocess_complete
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||||
log "Stage 3: Split medium and large subset into pieces"
|
log "Stage 3: Split medium and large subset into pieces"
|
||||||
num_per_split=200000
|
num_per_split=2500
|
||||||
split_dir=data/kmeans/medium_split
|
split_dir=data/kmeans/medium_split
|
||||||
if [ ! -f $split_dir/.split_completed ]; then
|
if [ ! -f $split_dir/.split_completed ]; then
|
||||||
lhotse split-lazy ./data/kmeans/librilight_cuts_medium_raw.jsonl.gz $split_dir $num_per_split
|
lhotse split-lazy ./data/kmeans/librilight_cuts_medium_raw.jsonl.gz $split_dir $num_per_split
|
||||||
@ -79,6 +79,12 @@ fi
|
|||||||
|
|
||||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||||
log "Stage 4: Extract SSL target for librilight"
|
log "Stage 4: Extract SSL target for librilight"
|
||||||
|
if [ ! -e download/hubert_base_ls960.pt ]; then
|
||||||
|
wget https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt -P download
|
||||||
|
fi
|
||||||
|
if [ ! -e download/hubert_base_ls960_L9_km500.bin ]; then
|
||||||
|
wget https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960_L9_km500.bin -P download
|
||||||
|
fi
|
||||||
if [ ! -e data/kmeans/.extract_small.done ]; then
|
if [ ! -e data/kmeans/.extract_small.done ]; then
|
||||||
./local/extract_kmeans_from_hubert_base.py --subset small
|
./local/extract_kmeans_from_hubert_base.py --subset small
|
||||||
touch data/kmeans/.extract_small.done
|
touch data/kmeans/.extract_small.done
|
||||||
|
Loading…
x
Reference in New Issue
Block a user