mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 18:42:19 +00:00
add sliding window
This commit is contained in:
parent
84f8adff32
commit
a6a80896d5
@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -98,44 +99,90 @@ def get_args():
|
|||||||
help="Stop processing pieces until this number (exclusive).",
|
help="Stop processing pieces until this number (exclusive).",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--window-duration",
|
||||||
|
type=float,
|
||||||
|
default=300.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--shift-duration",
|
||||||
|
type=float,
|
||||||
|
default=250.0,
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def extract_and_save_one_cuts(
|
def extract_and_save_one_cuts(
|
||||||
raw_cuts_path, cuts_path, model, apply_kmeans, do_normalize, device
|
raw_cuts_path,
|
||||||
|
cuts_path,
|
||||||
|
model,
|
||||||
|
apply_kmeans,
|
||||||
|
do_normalize,
|
||||||
|
window_duration,
|
||||||
|
shift_duration,
|
||||||
):
|
):
|
||||||
logging.info(f"Loading {raw_cuts_path}")
|
logging.info(f"Loading {raw_cuts_path}")
|
||||||
cut_set = CutSet.from_file(raw_cuts_path)
|
cut_set = CutSet.from_file(raw_cuts_path)
|
||||||
|
|
||||||
logging.info("Extracting kmeans")
|
logging.info("Extracting kmeans")
|
||||||
cuts = []
|
cuts = []
|
||||||
|
|
||||||
|
assert window_duration >= shift_duration
|
||||||
|
window_size = int(window_duration * 16000)
|
||||||
|
shift_size = int(shift_duration * 16000)
|
||||||
|
overlap_size = window_size - shift_size
|
||||||
|
out_overlap_size = get_out_length(overlap_size)
|
||||||
|
|
||||||
for cut in tqdm(cut_set):
|
for cut in tqdm(cut_set):
|
||||||
assert cut.sampling_rate == 16000, f"Sampling rate: {cut.sampling_rate}"
|
assert cut.sampling_rate == 16000, f"Sampling rate: {cut.sampling_rate}"
|
||||||
|
|
||||||
audio = cut.load_audio()
|
audio = cut.load_audio()
|
||||||
|
|
||||||
offsets = 0
|
T = audio.shape[1]
|
||||||
if True:
|
start = 0
|
||||||
x = torch.from_numpy(audio).float().to(device)
|
kmeans = []
|
||||||
|
while start < T:
|
||||||
|
real_window_size = min(window_size, T - start)
|
||||||
|
audio_window = audio[:, start : start + real_window_size]
|
||||||
|
|
||||||
with torch.no_grad():
|
x = (
|
||||||
if do_normalize:
|
torch.from_numpy(audio_window)
|
||||||
x = torch.nn.functional.layer_norm(x, x.shape)
|
.float()
|
||||||
|
.to(next(model.parameters()).device)
|
||||||
feature, _ = model.extract_features(
|
|
||||||
source=x,
|
|
||||||
padding_mask=None,
|
|
||||||
mask=False,
|
|
||||||
output_layer=9,
|
|
||||||
)
|
|
||||||
feature = feature.squeeze(0)
|
|
||||||
|
|
||||||
kmeans = " ".join(map(str, apply_kmeans(feature).tolist()))
|
|
||||||
|
|
||||||
cut_with_kmeans = fastcopy(
|
|
||||||
cut,
|
|
||||||
custom={"kmeans": kmeans},
|
|
||||||
)
|
)
|
||||||
cuts.append(cut_with_kmeans)
|
if do_normalize:
|
||||||
|
x = torch.nn.functional.layer_norm(x, x.shape)
|
||||||
|
|
||||||
|
feature, _ = model.extract_features(
|
||||||
|
source=x,
|
||||||
|
padding_mask=None,
|
||||||
|
mask=False,
|
||||||
|
output_layer=9,
|
||||||
|
)
|
||||||
|
feature = feature.squeeze(0)
|
||||||
|
|
||||||
|
current_kmeans = apply_kmeans(feature).tolist()
|
||||||
|
|
||||||
|
if start == 0:
|
||||||
|
kmeans.extend(current_kmeans)
|
||||||
|
else:
|
||||||
|
kmeans.extend(current_kmeans[out_overlap_size:])
|
||||||
|
|
||||||
|
if T - start <= window_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
start += shift_size
|
||||||
|
|
||||||
|
kmeans = " ".join(map(str, kmeans))
|
||||||
|
|
||||||
|
cut_with_kmeans = fastcopy(
|
||||||
|
cut,
|
||||||
|
custom={"kmeans": kmeans},
|
||||||
|
)
|
||||||
|
cuts.append(cut_with_kmeans)
|
||||||
|
|
||||||
cuts = CutSet(cuts)
|
cuts = CutSet(cuts)
|
||||||
|
|
||||||
@ -166,6 +213,9 @@ def extract_kmeans(args):
|
|||||||
model = model[0].eval().to(device)
|
model = model[0].eval().to(device)
|
||||||
do_normalize = task.cfg.normalize
|
do_normalize = task.cfg.normalize
|
||||||
|
|
||||||
|
window_duration = args.window_duration
|
||||||
|
shift_duration = args.shift_duration
|
||||||
|
|
||||||
if args.subset == "small":
|
if args.subset == "small":
|
||||||
cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
|
cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
|
||||||
if cuts_path.is_file():
|
if cuts_path.is_file():
|
||||||
@ -183,7 +233,8 @@ def extract_kmeans(args):
|
|||||||
model,
|
model,
|
||||||
apply_kmeans,
|
apply_kmeans,
|
||||||
do_normalize,
|
do_normalize,
|
||||||
device,
|
window_duration,
|
||||||
|
shift_duration,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
num_digits = 8 # num_digits is fixed by lhotse split-lazy
|
num_digits = 8 # num_digits is fixed by lhotse split-lazy
|
||||||
@ -213,10 +264,19 @@ def extract_kmeans(args):
|
|||||||
model,
|
model,
|
||||||
apply_kmeans,
|
apply_kmeans,
|
||||||
do_normalize,
|
do_normalize,
|
||||||
device,
|
window_duration,
|
||||||
|
shift_duration,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_out_length(T):
|
||||||
|
conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
|
||||||
|
for i, (out_channels, kernel_size, stride) in enumerate(conv_layers):
|
||||||
|
T = math.floor((T - kernel_size) / stride) + 1
|
||||||
|
|
||||||
|
return max(0, T)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
@ -86,15 +86,15 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
|||||||
wget https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960_L9_km500.bin -P download
|
wget https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960_L9_km500.bin -P download
|
||||||
fi
|
fi
|
||||||
if [ ! -e data/kmeans/.extract_small.done ]; then
|
if [ ! -e data/kmeans/.extract_small.done ]; then
|
||||||
./local/extract_kmeans_from_hubert_base.py --subset small
|
./local/extract_kmeans.py --subset small
|
||||||
touch data/kmeans/.extract_small.done
|
touch data/kmeans/.extract_small.done
|
||||||
fi
|
fi
|
||||||
if [ ! -e data/kmeans/.extract_medium.done ]; then
|
if [ ! -e data/kmeans/.extract_medium.done ]; then
|
||||||
./local/extract_kmeans_from_hubert_base.py --subset medium
|
./local/extract_kmeans.py --subset medium
|
||||||
touch data/kmeans/.extract_medium.done
|
touch data/kmeans/.extract_medium.done
|
||||||
fi
|
fi
|
||||||
if [ ! -e data/kmeans/.extract_large.done ]; then
|
if [ ! -e data/kmeans/.extract_large.done ]; then
|
||||||
./local/extract_kmeans_from_hubert_base.py --subset large
|
./local/extract_kmeans.py --subset large
|
||||||
touch data/kmeans/.extract_large.done
|
touch data/kmeans/.extract_large.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
Loading…
x
Reference in New Issue
Block a user