mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
initial commit for SURT AMI recipe
This commit is contained in:
parent
d6b88aaa98
commit
14818f5dd8
78
egs/ami/SURT/local/add_source_feats.py
Executable file
78
egs/ami/SURT/local/add_source_feats.py
Executable file
@ -0,0 +1,78 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file adds source features as temporal arrays to the mixture manifests.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from lhotse import CutSet, LilcomChunkyWriter, load_manifest, load_manifest_lazy
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def add_source_feats():
    """Attach per-source features to the clean and reverberant mixture cuts.

    Reads the mixed training cuts (``cuts_train_clean.jsonl.gz`` and
    ``cuts_train_reverb.jsonl.gz``) and the source cuts
    (``ihm_cuts_train_trimmed.jsonl.gz``) from ``data/manifests``. For each
    clean mixture, the features of the source cut behind every supervision
    (looked up by supervision id, ordered by start time and then speaker) are
    loaded and concatenated along axis 0. The result is stored with a
    ``LilcomChunkyWriter`` under ``data/fbank/feats_train_clean_sources``.

    Two custom fields are set on each written cut:

    * ``source_feats``: the stored concatenated array.
    * ``source_feat_offsets``: the starting frame of every source inside that
      array (cumulative sum of the sources' ``num_frames``).

    The reverberant cut reuses the very same stored array and offsets as its
    clean counterpart. Results are written to
    ``cuts_train_clean_sources.jsonl.gz`` and
    ``cuts_train_reverb_sources.jsonl.gz`` in ``data/manifests``.

    Raises:
      ValueError: if a clean cut and the reverberant cut at the same position
        do not form a pair (the reverb id must be ``<clean id>_rvb``).
    """
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")

    logging.info("Reading mixed cuts")
    mixed_cuts_clean = load_manifest_lazy(src_dir / "cuts_train_clean.jsonl.gz")
    mixed_cuts_reverb = load_manifest_lazy(src_dir / "cuts_train_reverb.jsonl.gz")

    logging.info("Reading source cuts")
    source_cuts = load_manifest(src_dir / "ihm_cuts_train_trimmed.jsonl.gz")

    logging.info("Adding source features to the mixed cuts")
    # The progress bar is managed by the ``with`` statement so that it is
    # closed (and its last state flushed) even if an error interrupts the loop.
    with tqdm(
        total=len(mixed_cuts_clean), desc="Adding source features"
    ) as pbar, CutSet.open_writer(
        src_dir / "cuts_train_clean_sources.jsonl.gz"
    ) as cut_writer_clean, CutSet.open_writer(
        src_dir / "cuts_train_reverb_sources.jsonl.gz"
    ) as cut_writer_reverb, LilcomChunkyWriter(
        output_dir / "feats_train_clean_sources"
    ) as source_feat_writer:
        for cut_clean, cut_reverb in zip(mixed_cuts_clean, mixed_cuts_reverb):
            # An explicit check instead of ``assert``: a mismatch would attach
            # the wrong sources to a mixture, and asserts vanish under ``-O``.
            if cut_reverb.id != cut_clean.id + "_rvb":
                raise ValueError(
                    f"Clean/reverb manifests are not aligned: "
                    f"{cut_clean.id} vs {cut_reverb.id}"
                )

            source_feats = []
            source_feat_offsets = []
            cur_offset = 0
            for sup in sorted(
                cut_clean.supervisions, key=lambda s: (s.start, s.speaker)
            ):
                source_cut = source_cuts[sup.id]
                source_feats.append(source_cut.load_features())
                source_feat_offsets.append(cur_offset)
                cur_offset += source_cut.num_frames

            # NOTE(review): a cut without supervisions leaves ``source_feats``
            # empty, in which case np.concatenate raises ValueError.
            cut_clean.source_feats = source_feat_writer.store_array(
                cut_clean.id, np.concatenate(source_feats, axis=0)
            )
            cut_clean.source_feat_offsets = source_feat_offsets
            cut_writer_clean.write(cut_clean)

            # Also write the reverb cut; it points at the same stored array.
            cut_reverb.source_feats = cut_clean.source_feats
            cut_reverb.source_feat_offsets = cut_clean.source_feat_offsets
            cut_writer_reverb.write(cut_reverb)

            pbar.update(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Timestamped INFO-level logging, then run the whole pipeline.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    add_source_feats()
|
185
egs/ami/SURT/local/compute_fbank_aimix.py
Executable file
185
egs/ami/SURT/local/compute_fbank_aimix.py
Executable file
@ -0,0 +1,185 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the synthetically mixed AMI and ICSI
|
||||
train set.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
import logging
|
||||
import random
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
import torchaudio
|
||||
from lhotse import (
|
||||
AudioSource,
|
||||
LilcomChunkyWriter,
|
||||
Recording,
|
||||
load_manifest,
|
||||
load_manifest_lazy,
|
||||
)
|
||||
from lhotse.audio import set_ffmpeg_torchaudio_info_enabled
|
||||
from lhotse.cut import MixedCut, MixTrack, MultiCut
|
||||
from lhotse.features.kaldifeat import (
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
KaldifeatFrameOptions,
|
||||
KaldifeatMelOptions,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed, uuid4
|
||||
from tqdm import tqdm
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled, otherwise it wastes
# a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
# Audio I/O configuration applied at import time.
# NOTE(review): presumably forces soundfile-based reading and skips the
# ffmpeg-backed torchaudio info call in lhotse; confirm against the installed
# torchaudio/lhotse versions.
torchaudio.set_audio_backend("soundfile")
set_ffmpeg_torchaudio_info_enabled(False)
|
||||
|
||||
|
||||
def compute_fbank_aimix():
    """Extract fbank features for the clean and augmented AI-mix training cuts.

    Loads ``ai-mix_cuts_clean_full.jsonl.gz`` from ``data/manifests`` and
    computes 80-bin, 16 kHz Kaldifeat fbank features on ``cuda`` twice:

    * for the cuts as they are (features ``ai-mix_feats_clean``, manifest
      ``cuts_train_clean.jsonl.gz``);
    * for a copy passed through :func:`augment` with SNR perturbation, RIRs,
      noises and loudness perturbation (features ``ai-mix_feats_reverb``,
      manifest ``cuts_train_reverb.jsonl.gz``).

    RIR and noise recordings are restricted to those whose id contains
    ``"RVB2014"``. Feature archives go to ``data/fbank``; manifests are
    written back to ``data/manifests``. Existing outputs are overwritten.
    """
    manifest_dir = Path("data/manifests")
    fbank_dir = Path("data/fbank")

    fbank_config = KaldifeatFbankConfig(
        frame_opts=KaldifeatFrameOptions(sampling_rate=16000),
        mel_opts=KaldifeatMelOptions(num_bins=80),
        device="cuda",
    )
    extractor = KaldifeatFbank(fbank_config)

    logging.info("Reading manifests")
    train_cuts = load_manifest_lazy(manifest_dir / "ai-mix_cuts_clean_full.jsonl.gz")

    # only uses RIRs and noises from REVERB challenge
    all_rirs = load_manifest(manifest_dir / "real-rir_recordings_all.jsonl.gz")
    real_rirs = all_rirs.filter(lambda r: "RVB2014" in r.id)
    all_noises = load_manifest(manifest_dir / "iso-noise_recordings_all.jsonl.gz")
    noises = all_noises.filter(lambda r: "RVB2014" in r.id)

    # Apply perturbation to the training cuts
    logging.info("Applying perturbation to the training cuts")
    train_cuts_rvb = train_cuts.map(
        lambda c: augment(
            c, perturb_snr=True, rirs=real_rirs, noises=noises, perturb_loudness=True
        )
    )

    # Options common to both extraction passes.
    extraction_opts = dict(
        extractor=extractor,
        batch_duration=5000,
        num_workers=4,
        storage_type=LilcomChunkyWriter,
        overwrite=True,
    )

    logging.info("Extracting fbank features for training cuts")
    train_cuts.compute_and_store_features_batch(
        storage_path=fbank_dir / "ai-mix_feats_clean",
        manifest_path=manifest_dir / "cuts_train_clean.jsonl.gz",
        **extraction_opts,
    )

    # Warnings raised while processing the augmented cuts are silenced.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        train_cuts_rvb.compute_and_store_features_batch(
            storage_path=fbank_dir / "ai-mix_feats_reverb",
            manifest_path=manifest_dir / "cuts_train_reverb.jsonl.gz",
            **extraction_opts,
        )
|
||||
|
||||
|
||||
def augment(cut, perturb_snr=False, rirs=None, noises=None, perturb_loudness=False):
    """Return an augmented copy of a mixed cut, with its features dropped.

    Each augmentation is optional and they are applied in this order:

    1. SNR perturbation (``perturb_snr``): every track except the first,
       which is the reference, gets an SNR drawn uniformly from [-5, 5] dB.
    2. Reverberation (``rirs`` is not None): one RIR recording and one of its
       channels are picked at random and applied to the cut.
    3. Additive noise (``noises`` is not None): a random noise recording
       (its first mono cut if it is a ``MultiCut``) is repeated to the cut's
       duration and mixed in at an SNR drawn uniformly from [10, 30] dB.
    4. Loudness perturbation (``perturb_loudness``): the cut is normalized to
       a target loudness drawn uniformly from [-25, -20] dB.

    Args:
      cut: the mixed cut to augment; it is accessed through ``drop_features``
        and, for SNR perturbation, ``tracks``.
      perturb_snr: whether to perturb the SNRs of the tracks.
      rirs: collection of RIR recordings to sample from, or None.
      noises: collection of noise recordings to sample from, or None.
      perturb_loudness: whether to perturb the loudness.

    Returns:
      The augmented cut.
    """
    augmented = cut.drop_features()

    # Perturb the SNRs (optional). One value is drawn per track, including the
    # first one, whose value is then simply not used.
    if perturb_snr:
        track_snrs = [random.uniform(-5, 5) for _ in cut.tracks]
        for track, track_snr in zip(augmented.tracks[1:], track_snrs[1:]):
            track.snr = track_snr

    # Reverberate the cut (optional) with a random RIR and a random channel.
    if rirs is not None:
        rir = random.choice(rirs)
        channel = random.choice(range(rir.num_channels))
        augmented = augmented.reverb_rir(rir_recording=rir, rir_channels=[channel])

    # Add noise (optional).
    if noises is not None:
        noise = random.choice(noises).to_cut()
        if isinstance(noise, MultiCut):
            noise = noise.to_mono()[0]
        noise_snr = random.uniform(10, 30)
        # Repeat the noise to match the duration of the cut
        noise = repeat_cut(noise, augmented.duration)
        speech_track = MixTrack(cut=augmented, type="MixedCut")
        noise_track = MixTrack(cut=noise, type="DataCut", snr=noise_snr)
        augmented = MixedCut(id=augmented.id, tracks=[speech_track, noise_track])

    # Perturb the loudness (optional).
    if perturb_loudness:
        target_loudness = random.uniform(-20, -25)
        augmented = augmented.normalize_loudness(target_loudness, mix_first=True)

    return augmented
|
||||
|
||||
|
||||
def repeat_cut(cut, duration):
    """Repeat ``cut`` back-to-back until it covers ``duration`` seconds.

    The cut is repeatedly mixed with itself, offset by its own current
    duration, until it is at least ``duration`` long. It is then truncated to
    exactly ``duration``.

    Args:
      cut: the cut to repeat; must provide ``duration``, ``mix`` and
        ``truncate``.
      duration: the target duration in seconds.

    Returns:
      The repeated cut truncated to ``duration``.

    Raises:
      ValueError: if ``cut`` has a non-positive duration and therefore could
        never be extended to reach ``duration``.
    """
    # A zero-length cut appended to itself is still zero-length, so the loop
    # below would never terminate. Fail loudly instead of hanging.
    if cut.duration <= 0 and cut.duration < duration:
        raise ValueError(
            f"Cannot repeat a cut of duration {cut.duration} "
            f"to reach a duration of {duration}"
        )
    while cut.duration < duration:
        cut = cut.mix(cut, offset_other_by=cut.duration)
    return cut.truncate(duration=duration)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Timestamped INFO-level logging.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    # Seed the random generators before any augmentation is sampled.
    fix_random_seed(42)
    compute_fbank_aimix()
|
94
egs/ami/SURT/local/compute_fbank_ami.py
Executable file
94
egs/ami/SURT/local/compute_fbank_ami.py
Executable file
@ -0,0 +1,94 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the AMI dataset.
|
||||
We compute features for full recordings (i.e., without trimming to supervisions).
|
||||
This way we can create arbitrary segmentations later.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
from lhotse import CutSet, LilcomChunkyWriter
|
||||
from lhotse.features.kaldifeat import (
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
KaldifeatFrameOptions,
|
||||
KaldifeatMelOptions,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled, otherwise it wastes
# a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
|
||||
|
||||
def compute_fbank_ami():
    """Extract fbank features for the full AMI recordings.

    For each microphone setting (``ihm-mix``, ``sdm``, ``mdm8-bf``) and each
    split (``train``, ``dev``, ``test``), the cached ``ami-<setting>``
    manifests in ``data/manifests`` are turned into a ``CutSet`` and 80-bin,
    16 kHz Kaldifeat fbank features are computed on ``cuda``.

    Features are stored in ``data/fbank/ami-<setting>_<split>_feats`` and the
    resulting cuts in ``data/manifests/cuts_ami-<setting>_<split>.jsonl.gz``.
    """
    manifest_dir = Path("data/manifests")
    fbank_dir = Path("data/fbank")
    mic_settings = ("ihm-mix", "sdm", "mdm8-bf")
    splits = ("train", "dev", "test")

    fbank_config = KaldifeatFbankConfig(
        frame_opts=KaldifeatFrameOptions(sampling_rate=16000),
        mel_opts=KaldifeatMelOptions(num_bins=80),
        device="cuda",
    )
    extractor = KaldifeatFbank(fbank_config)

    logging.info("Reading manifests")
    # All manifests are read up front, before any feature extraction starts.
    manifests = {
        part: read_manifests_if_cached(
            dataset_parts=list(splits),
            output_dir=manifest_dir,
            prefix=f"ami-{part}",
            suffix="jsonl.gz",
        )
        for part in mic_settings
    }

    for part in mic_settings:
        for split in splits:
            logging.info(f"Processing {part} {split}")
            cut_set = CutSet.from_manifests(**manifests[part][split])
            cut_set.compute_and_store_features_batch(
                extractor=extractor,
                storage_path=fbank_dir / f"ami-{part}_{split}_feats",
                manifest_path=manifest_dir / f"cuts_ami-{part}_{split}.jsonl.gz",
                batch_duration=5000,
                num_workers=4,
                storage_type=LilcomChunkyWriter,
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Timestamped INFO-level logging, then run the extraction.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    compute_fbank_ami()
|
95
egs/ami/SURT/local/compute_fbank_icsi.py
Executable file
95
egs/ami/SURT/local/compute_fbank_icsi.py
Executable file
@ -0,0 +1,95 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the ICSI dataset.
|
||||
We compute features for full recordings (i.e., without trimming to supervisions).
|
||||
This way we can create arbitrary segmentations later.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
from lhotse import CutSet, LilcomChunkyWriter
|
||||
from lhotse.features.kaldifeat import (
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
KaldifeatFrameOptions,
|
||||
KaldifeatMelOptions,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled, otherwise it wastes
# a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
|
||||
|
||||
def compute_fbank_icsi():
    """Extract fbank features for the full ICSI training recordings.

    For each microphone setting (``ihm-mix``, ``sdm``), the cached
    ``icsi-<setting>`` train manifests in ``data/manifests`` are turned into a
    ``CutSet`` and 80-bin, 16 kHz Kaldifeat fbank features are computed on
    ``cuda``.

    Features are stored in ``data/fbank/icsi-<setting>_train_feats`` and the
    resulting cuts in ``data/manifests/cuts_icsi-<setting>_train.jsonl.gz``.
    Existing outputs are overwritten.
    """
    manifest_dir = Path("data/manifests")
    fbank_dir = Path("data/fbank")
    mic_settings = ("ihm-mix", "sdm")
    splits = ("train",)

    fbank_config = KaldifeatFbankConfig(
        frame_opts=KaldifeatFrameOptions(sampling_rate=16000),
        mel_opts=KaldifeatMelOptions(num_bins=80),
        device="cuda",
    )
    extractor = KaldifeatFbank(fbank_config)

    logging.info("Reading manifests")
    # All manifests are read up front, before any feature extraction starts.
    manifests = {
        part: read_manifests_if_cached(
            dataset_parts=list(splits),
            output_dir=manifest_dir,
            prefix=f"icsi-{part}",
            suffix="jsonl.gz",
        )
        for part in mic_settings
    }

    for part in mic_settings:
        for split in splits:
            logging.info(f"Processing {part} {split}")
            cut_set = CutSet.from_manifests(**manifests[part][split])
            cut_set.compute_and_store_features_batch(
                extractor=extractor,
                storage_path=fbank_dir / f"icsi-{part}_{split}_feats",
                manifest_path=manifest_dir / f"cuts_icsi-{part}_{split}.jsonl.gz",
                batch_duration=5000,
                num_workers=4,
                storage_type=LilcomChunkyWriter,
                overwrite=True,
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Timestamped INFO-level logging, then run the extraction.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    compute_fbank_icsi()
|
101
egs/ami/SURT/local/compute_fbank_ihm.py
Executable file
101
egs/ami/SURT/local/compute_fbank_ihm.py
Executable file
@ -0,0 +1,101 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the trimmed sub-segments which will be
|
||||
used for simulating the training mixtures.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
import torchaudio
|
||||
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
|
||||
from lhotse.audio import set_ffmpeg_torchaudio_info_enabled
|
||||
from lhotse.features.kaldifeat import (
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
KaldifeatFrameOptions,
|
||||
KaldifeatMelOptions,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
from tqdm import tqdm
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled, otherwise it wastes
# a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
# Audio I/O configuration applied at import time.
# NOTE(review): presumably forces soundfile-based reading and skips the
# ffmpeg-backed torchaudio info call in lhotse; confirm against the installed
# torchaudio/lhotse versions.
torchaudio.set_audio_backend("soundfile")
set_ffmpeg_torchaudio_info_enabled(False)
|
||||
|
||||
|
||||
def compute_fbank_ihm():
    """Extract fbank features for the trimmed IHM segments of AMI and ICSI.

    For each corpus, the cached ``<corpus>-ihm`` train recordings and
    supervisions in ``data/manifests`` are turned into a ``CutSet``, which is
    then:

    1. trimmed to individual supervisions (overlapping ones are not kept);
    2. loudness-normalized to -23.0 without changing cut ids;
    3. extended with 0.9x and 1.1x speed-perturbed copies.

    80-bin, 16 kHz Kaldifeat fbank features are computed on ``cuda`` and saved
    to ``data/fbank/<corpus>-ihm_train_feats``; the cuts are written to
    ``data/manifests/<corpus>-ihm_cuts_train.jsonl.gz``. Existing outputs are
    overwritten.
    """
    manifest_dir = Path("data/manifests")
    fbank_dir = Path("data/fbank")
    corpora = ("ami", "icsi")

    fbank_config = KaldifeatFbankConfig(
        frame_opts=KaldifeatFrameOptions(sampling_rate=16000),
        mel_opts=KaldifeatMelOptions(num_bins=80),
        device="cuda",
    )
    extractor = KaldifeatFbank(fbank_config)

    logging.info("Reading manifests")
    # All manifests are read up front, before any feature extraction starts.
    manifests = {
        corpus: read_manifests_if_cached(
            dataset_parts=["train"],
            output_dir=manifest_dir,
            types=["recordings", "supervisions"],
            prefix=f"{corpus}-ihm",
            suffix="jsonl.gz",
        )
        for corpus in corpora
    }

    logging.info("Computing features")
    for corpus in corpora:
        segments = CutSet.from_manifests(**manifests[corpus]["train"])
        segments = segments.trim_to_supervisions(keep_overlapping=False)
        segments = segments.normalize_loudness(target=-23.0, affix_id=False)
        # 3-fold speed perturbation: original + slower + faster copies.
        slower = segments.perturb_speed(0.9)
        faster = segments.perturb_speed(1.1)
        segments = segments + slower + faster
        segments.compute_and_store_features_batch(
            extractor=extractor,
            storage_path=fbank_dir / f"{corpus}-ihm_train_feats",
            manifest_path=manifest_dir / f"{corpus}-ihm_cuts_train.jsonl.gz",
            batch_duration=5000,
            num_workers=4,
            storage_type=LilcomChunkyWriter,
            overwrite=True,
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Timestamped INFO-level logging, then run the extraction.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    compute_fbank_ihm()
|
146
egs/ami/SURT/local/prepare_ami_train_cuts.py
Executable file
146
egs/ami/SURT/local/prepare_ami_train_cuts.py
Executable file
@ -0,0 +1,146 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file creates AMI train segments.
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
from lhotse import LilcomChunkyWriter, load_manifest_lazy
|
||||
from lhotse.cut import Cut, CutSet
|
||||
from lhotse.utils import EPSILON, add_durations
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def cut_into_windows(cuts: CutSet, duration: float):
    """Cut every cut of a CutSet into windows of roughly ``duration`` seconds.

    "Roughly" because a window boundary is adjusted around the supervisions
    that straddle it: a straddling supervision with more than half of its
    duration inside the window is retained (the window is extended up to its
    end), while the others are discarded from this window and the next window
    starts at the earliest discarded supervision.

    Windows are produced with ``keep_excessive_supervisions=False``, so a
    supervision that does not fit entirely inside a window is not part of it.

    Args:
      cuts: the cuts to split.
      duration: the nominal window duration in seconds.

    Returns:
      A new ``CutSet`` holding all windows of all input cuts.
    """
    res = []
    with tqdm() as pbar:
        for cut in cuts:
            pbar.update(1)
            sups = cut.index_supervisions()[cut.id]
            sr = cut.sampling_rate
            start = 0.0
            end = duration
            # Counts how often the window had to be force-advanced because it
            # did not move by itself; the cut is abandoned after two of those.
            num_tries = 0
            while start < cut.duration and num_tries < 2:
                # Find the supervisions that are cut by the window endpoint
                hitlist = [iv for iv in sups.at(end) if iv.begin < end]
                # If there are no supervisions, we are done
                if not hitlist:
                    res.append(
                        cut.truncate(
                            offset=start,
                            duration=add_durations(end, -start, sampling_rate=sr),
                            keep_excessive_supervisions=False,
                        )
                    )
                    # Update the start and end for the next window
                    start = end
                    end = add_durations(end, duration, sampling_rate=sr)
                else:
                    # Fraction of each straddling supervision that lies inside
                    # the window, i.e. before the window endpoint. It must be
                    # measured from the supervision's beginning: measuring
                    # from its end always gives a negative value (the end lies
                    # after the endpoint), so nothing would ever be retained.
                    ratios = [
                        add_durations(end, -iv.begin, sampling_rate=sr) / iv.length()
                        for iv in hitlist
                    ]
                    # we retain the supervisions that have >50% of their duration
                    # in the window, and discard the others
                    retained = []
                    discarded = []
                    for iv, ratio in zip(hitlist, ratios):
                        if ratio > 0.5:
                            retained.append(iv)
                        else:
                            discarded.append(iv)
                    # Extend the window so that retained supervisions fit.
                    cur_end = max(iv.end for iv in retained) if retained else end
                    res.append(
                        cut.truncate(
                            offset=start,
                            duration=add_durations(cur_end, -start, sampling_rate=sr),
                            keep_excessive_supervisions=False,
                        )
                    )
                    # For the next window, we start at the earliest discarded supervision
                    next_start = min(iv.begin for iv in discarded) if discarded else end
                    next_end = add_durations(next_start, duration, sampling_rate=sr)
                    # It may happen that next_start is the same as start, in which case
                    # we will advance the window anyway
                    if next_start == start:
                        logging.warning(
                            f"Next start is the same as start: {next_start} == {start} for cut {cut.id}"
                        )
                        start = end + EPSILON
                        end = add_durations(start, duration, sampling_rate=sr)
                        num_tries += 1
                    else:
                        start = next_start
                        end = next_end
    return CutSet.from_cuts(res)
|
||||
|
||||
|
||||
def prepare_train_cuts():
    """Create the AMI training segments and save them as one manifest.

    The ``ihm-mix``, ``sdm`` and ``mdm8-bf`` training cuts are loaded lazily
    from ``data/manifests``, each cut id being suffixed with its microphone
    setting, and combined. The combined set is trimmed to supervision groups
    twice (``max_pause`` of 0.5 and of 0.0) and both segmentations are joined.
    The result is cut into windows of roughly 30 seconds, described, and saved
    to ``data/manifests/cuts_train_ami.jsonl.gz``.
    """
    src_dir = Path("data/manifests")

    logging.info("Loading the manifests")
    train_cuts_ihm = load_manifest_lazy(
        src_dir / "cuts_ami-ihm-mix_train.jsonl.gz"
    ).map(lambda c: c.with_id(f"{c.id}_ihm-mix"))
    train_cuts_sdm = load_manifest_lazy(src_dir / "cuts_ami-sdm_train.jsonl.gz").map(
        lambda c: c.with_id(f"{c.id}_sdm")
    )
    train_cuts_mdm = load_manifest_lazy(
        src_dir / "cuts_ami-mdm8-bf_train.jsonl.gz"
    ).map(lambda c: c.with_id(f"{c.id}_mdm8-bf"))

    # Combine all cuts into one CutSet
    train_cuts = train_cuts_ihm + train_cuts_sdm + train_cuts_mdm

    train_cuts_1 = train_cuts.trim_to_supervision_groups(max_pause=0.5)
    train_cuts_2 = train_cuts.trim_to_supervision_groups(max_pause=0.0)

    # Combine the two segmentations
    train_all = train_cuts_1 + train_cuts_2

    # At this point, some of the cuts may be very long. We will cut them into windows of
    # roughly 30 seconds.
    logging.info("Cutting the segments into windows of 30 seconds")
    train_all_30 = cut_into_windows(train_all, duration=30.0)
    logging.info(f"Number of cuts after cutting into windows: {len(train_all_30)}")

    # The windowed cuts are the product of this script: report and save those,
    # not the un-windowed segments they were derived from (otherwise the
    # windowing above is computed and thrown away).

    # Show statistics
    train_all_30.describe(full=True)

    # Save the cuts
    logging.info("Saving the cuts")
    train_all_30.to_file(src_dir / "cuts_train_ami.jsonl.gz")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Timestamped INFO-level logging, then build the training cuts.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    prepare_train_cuts()
|
67
egs/ami/SURT/local/prepare_icsi_train_cuts.py
Executable file
67
egs/ami/SURT/local/prepare_icsi_train_cuts.py
Executable file
@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file creates ICSI train segments.
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from lhotse import load_manifest_lazy
|
||||
from prepare_ami_train_cuts import cut_into_windows
|
||||
|
||||
|
||||
def prepare_train_cuts():
    """Create the ICSI training segments and save them as one manifest.

    The ``ihm-mix`` and ``sdm`` training cuts are loaded lazily from
    ``data/manifests``, each cut id being suffixed with its microphone
    setting, and combined. The combined set is trimmed to supervision groups
    twice (``max_pause`` of 0.5 and of 0.0) and both segmentations are joined.
    The result is cut into windows of roughly 30 seconds, described, and saved
    to ``data/manifests/cuts_train_icsi.jsonl.gz``.
    """
    src_dir = Path("data/manifests")

    logging.info("Loading the manifests")
    train_cuts_ihm = load_manifest_lazy(
        src_dir / "cuts_icsi-ihm-mix_train.jsonl.gz"
    ).map(lambda c: c.with_id(f"{c.id}_ihm-mix"))
    train_cuts_sdm = load_manifest_lazy(src_dir / "cuts_icsi-sdm_train.jsonl.gz").map(
        lambda c: c.with_id(f"{c.id}_sdm")
    )

    # Combine all cuts into one CutSet
    train_cuts = train_cuts_ihm + train_cuts_sdm

    train_cuts_1 = train_cuts.trim_to_supervision_groups(max_pause=0.5)
    train_cuts_2 = train_cuts.trim_to_supervision_groups(max_pause=0.0)

    # Combine the two segmentations
    train_all = train_cuts_1 + train_cuts_2

    # At this point, some of the cuts may be very long. We will cut them into windows of
    # roughly 30 seconds.
    logging.info("Cutting the segments into windows of 30 seconds")
    train_all_30 = cut_into_windows(train_all, duration=30.0)
    logging.info(f"Number of cuts after cutting into windows: {len(train_all_30)}")

    # The windowed cuts are the product of this script: report and save those,
    # not the un-windowed segments they were derived from (otherwise the
    # windowing above is computed and thrown away).

    # Show statistics
    train_all_30.describe(full=True)

    # Save the cuts
    logging.info("Saving the cuts")
    train_all_30.to_file(src_dir / "cuts_train_icsi.jsonl.gz")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Timestamped INFO-level logging, then build the training cuts.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    prepare_train_cuts()
|
413
egs/ami/SURT/local/prepare_lang.py
Executable file
413
egs/ami/SURT/local/prepare_lang.py
Executable file
@ -0,0 +1,413 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
|
||||
consisting of words and tokens (i.e., phones) and does the following:
|
||||
|
||||
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
|
||||
|
||||
2. Generate tokens.txt, the token table mapping a token to a unique integer.
|
||||
|
||||
3. Generate words.txt, the word table mapping a word to a unique integer.
|
||||
|
||||
4. Generate L.pt, in k2 format. It can be loaded by
|
||||
|
||||
d = torch.load("L.pt")
|
||||
lexicon = k2.Fsa.from_dict(d)
|
||||
|
||||
5. Generate L_disambig.pt, in k2 format.
|
||||
"""
|
||||
import argparse
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from icefall.lexicon import read_lexicon, write_lexicon
|
||||
from icefall.utils import str2bool
|
||||
|
||||
# A lexicon is a list of (word, tokens) entries, where `tokens` is the list
# of tokens the word is made of.
Lexicon = List[Tuple[str, List[str]]]
|
||||
|
||||
|
||||
def get_args():
    """Parse the command-line options of this script.

    Returns:
      The namespace produced by ``argparse``; it carries ``lang_dir`` and
      ``debug``.
    """
    lang_dir_help = (
        "Input and output directory.\n"
        "It should contain a file lexicon.txt.\n"
        "Generated files by this script are saved into this directory.\n"
    )
    debug_help = (
        "True for debugging, which will generate\n"
        "a visualization of the lexicon FST.\n"
        "\n"
        "Caution: If your lexicon contains hundreds of thousands\n"
        "of lines, please set it to False!\n"
    )

    parser = argparse.ArgumentParser()
    parser.add_argument("--lang-dir", type=str, help=lang_dir_help)
    parser.add_argument("--debug", type=str2bool, default=False, help=debug_help)
    return parser.parse_args()
|
||||
|
||||
|
||||
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
    """Save a symbol-to-ID table as text, one ``<symbol> <id>`` pair per line.

    Note:
      A matching ``read_mapping`` is not needed, because the file can be
      loaded with :func:`k2.SymbolTable.from_file`.

    Args:
      filename:
        Path of the file to (over)write; it is written as UTF-8.
      sym2id:
        The table to save, mapping each symbol to its integer ID.
    Returns:
      Return None.
    """
    with open(filename, "w", encoding="utf-8") as out:
        out.writelines(f"{symbol} {idx}\n" for symbol, idx in sym2id.items())
|
||||
|
||||
|
||||
def get_tokens(lexicon: Lexicon) -> List[str]:
    """Collect every distinct token used by the pronunciations of a lexicon.

    Args:
      lexicon:
        The (word, tokens) entries, e.g. as returned by :func:`read_lexicon`.
    Returns:
      The unique tokens, sorted in ascending order.
    """
    unique_tokens = {token for _, tokens in lexicon for token in tokens}
    return sorted(unique_tokens)
|
||||
|
||||
|
||||
def get_words(lexicon: Lexicon) -> List[str]:
    """Collect every distinct word that has an entry in a lexicon.

    Args:
      lexicon:
        The (word, tokens) entries, e.g. as returned by :func:`read_lexicon`.
    Returns:
      The unique words, sorted in ascending order.
    """
    return sorted({word for word, _ in lexicon})
|
||||
|
||||
|
||||
def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
    """Append pseudo-tokens #1, #2, ... to ambiguous pronunciations.

    After this step every token sequence in the lexicon is distinct and
    none of them is a prefix of another one.

    See also add_lex_disambig.pl from kaldi.

    Args:
      lexicon:
        The (word, tokens) entries, e.g. as returned by :func:`read_lexicon`.
    Returns:
      Return a tuple with two elements:

        - The output lexicon with disambiguation symbols
        - The ID of the max disambiguation symbol that appears
          in the lexicon
    """
    # Pass 1: count how many entries share each token sequence.
    num_entries_with = defaultdict(int)
    for _, tokens in lexicon:
        num_entries_with[" ".join(tokens)] += 1

    # Pass 2: remember every proper, non-empty prefix of every token
    # sequence, so that sequences which start a longer one can be spotted.
    proper_prefixes = set()
    for _, tokens in lexicon:
        remaining = tokens.copy()
        remaining.pop()
        while remaining:
            proper_prefixes.add(" ".join(remaining))
            remaining.pop()

    # Pass 3: an entry is left untouched when its token sequence is unique
    # and is not a prefix of another one. Otherwise it receives the next
    # unused symbol for that token sequence: #1 first, then #2, #3, ...

    # We start with #1 since #0 has its own purpose
    first_allowed_disambig = 1
    max_disambig = first_allowed_disambig - 1
    last_used = {}  # token sequence -> number of the last symbol given to it

    disambiguated = []
    for word, tokens in lexicon:
        tokenseq = " ".join(tokens)
        assert tokenseq != ""

        is_ambiguous = (
            tokenseq in proper_prefixes or num_entries_with[tokenseq] != 1
        )
        if not is_ambiguous:
            disambiguated.append((word, tokens))
            continue

        previous = last_used.get(tokenseq, 0)
        cur_disambig = first_allowed_disambig if previous == 0 else previous + 1

        max_disambig = max(max_disambig, cur_disambig)
        last_used[tokenseq] = cur_disambig
        disambiguated.append((word, f"{tokenseq} #{cur_disambig}".split()))

    return disambiguated, max_disambig
|
||||
|
||||
|
||||
def generate_id_map(symbols: List[str]) -> Dict[str, int]:
    """Number the given symbols 0, 1, 2, ... in the order they are listed.

    Args:
      symbols:
        A list of unique symbols.
    Returns:
      A dict mapping every symbol to its position in `symbols`.
    """
    sym2id = {}
    for idx, symbol in enumerate(symbols):
        sym2id[symbol] = idx
    return sym2id
|
||||
|
||||
|
||||
def add_self_loops(
    arcs: List[List[Any]], disambig_token: int, disambig_word: int
) -> List[List[Any]]:
    """Add self-loops that let disambiguation symbols pass through an FST.

    A self-loop is added to every state that has at least one outgoing arc
    whose output symbol is not epsilon (0).

    See also fstaddselfloops.pl from Kaldi. One difference is that
    Kaldi uses OpenFst style FSTs and it has multiple final states.
    This function uses k2 style FSTs and it does not need to add self-loops
    to the final state.

    The input label of a self-loop is `disambig_token`, while the output
    label is `disambig_word`.

    Args:
      arcs:
        A list-of-list. The sublist contains
        `[src_state, dest_state, label, aux_label, score]`
      disambig_token:
        It is the token ID of the symbol `#0`.
      disambig_word:
        It is the word ID of the symbol `#0`.

    Return:
      Return new `arcs` containing self-loops. The input list is not
      modified.
    """
    # Unpack all five fields so that a malformed arc is rejected here.
    looping_states = {
        src for src, _dst, _ilabel, olabel, _score in arcs if olabel != 0
    }
    self_loops = [
        [state, state, disambig_token, disambig_word, 0]
        for state in looping_states
    ]
    return arcs + self_loops
|
||||
|
||||
|
||||
def lexicon_to_fst(
    lexicon: Lexicon,
    token2id: Dict[str, int],
    word2id: Dict[str, int],
    sil_token: str = "SIL",
    sil_prob: float = 0.5,
    need_self_loops: bool = False,
) -> k2.Fsa:
    """Build the lexicon FST (in k2 format), with optional silence at
    the beginning and end of each word.

    Args:
      lexicon:
        The input lexicon. See also :func:`read_lexicon`
      token2id:
        A dict mapping tokens to IDs.
      word2id:
        A dict mapping words to IDs.
      sil_token:
        The silence token.
      sil_prob:
        The probability for adding a silence at the beginning and end
        of the word. It must lie strictly between 0 and 1.
      need_self_loops:
        If True, add self-loop to states with non-epsilon output symbols
        on at least one arc out of the state. The input label for this
        self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
    Returns:
      Return an instance of `k2.Fsa` representing the given lexicon.
    """
    assert 0.0 < sil_prob < 1.0
    # CAUTION: we use score, i.e, negative cost.
    sil_score = math.log(sil_prob)
    no_sil_score = math.log(1.0 - sil_prob)

    start_state = 0
    loop_state = 1  # words enter and leave from here
    sil_state = 2  # words terminate here when followed by silence; this state
    # has a silence transition to loop_state.
    next_state = 3  # the next un-allocated state, will be incremented as we go.

    assert token2id["<eps>"] == 0
    assert word2id["<eps>"] == 0
    eps = 0

    sil_token_id = token2id[sil_token]

    arcs = [
        [start_state, loop_state, eps, eps, no_sil_score],
        [start_state, sil_state, eps, eps, sil_score],
        [sil_state, loop_state, sil_token_id, eps, 0],
    ]

    for word, tokens in lexicon:
        assert len(tokens) > 0, f"{word} has no pronunciations"

        word_id = word2id[word]
        token_ids = [token2id[token] for token in tokens]
        *leading_ids, last_id = token_ids

        # The word label is emitted on the first arc of the pronunciation;
        # every later arc of the same word outputs epsilon.
        state = loop_state
        out_label = word_id
        for token_id in leading_ids:
            arcs.append([state, next_state, token_id, out_label, 0])
            state = next_state
            next_state += 1
            out_label = eps

        # The last token has two out-going arcs: one back to the loop
        # state, the other one to the sil_state.
        arcs.append([state, loop_state, last_id, out_label, no_sil_score])
        arcs.append([state, sil_state, last_id, out_label, sil_score])

    if need_self_loops:
        arcs = add_self_loops(
            arcs,
            disambig_token=token2id["#0"],
            disambig_word=word2id["#0"],
        )

    final_state = next_state
    arcs.append([loop_state, final_state, -1, -1, 0])
    arcs.append([final_state])

    # Serialize the arcs ordered by source state, one arc per line.
    fsa_text = "\n".join(
        " ".join(str(field) for field in arc)
        for arc in sorted(arcs, key=lambda arc: arc[0])
    )
    return k2.Fsa.from_str(fsa_text, acceptor=False)
|
||||
|
||||
|
||||
def main():
    """Generate tokens.txt, words.txt, lexicon_disambig.txt, L.pt and
    L_disambig.pt inside `--lang-dir`, starting from its lexicon.txt.
    """
    args = get_args()
    lang_dir = Path(args.lang_dir)
    sil_token = "SIL"
    sil_prob = 0.5

    lexicon = read_lexicon(lang_dir / "lexicon.txt")
    tokens = get_tokens(lexicon)
    words = get_words(lexicon)

    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)

    # Register #0 .. #max_disambig as extra tokens after the real ones.
    for i in range(max_disambig + 1):
        disambig = f"#{i}"
        assert disambig not in tokens
        tokens.append(disambig)

    assert "<eps>" not in tokens
    tokens = ["<eps>", *tokens]

    assert "<eps>" not in words
    assert "#0" not in words
    assert "<s>" not in words
    assert "</s>" not in words
    words = ["<eps>", *words, "#0", "<s>", "</s>"]

    token2id = generate_id_map(tokens)
    word2id = generate_id_map(words)

    write_mapping(lang_dir / "tokens.txt", token2id)
    write_mapping(lang_dir / "words.txt", word2id)
    write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)

    fst_options = dict(
        token2id=token2id,
        word2id=word2id,
        sil_token=sil_token,
        sil_prob=sil_prob,
    )
    L = lexicon_to_fst(lexicon, **fst_options)
    L_disambig = lexicon_to_fst(
        lexicon_disambig, need_self_loops=True, **fst_options
    )
    torch.save(L.as_dict(), lang_dir / "L.pt")
    torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")

    if not args.debug:
        return

    # Debug mode: attach the symbol tables and render both FSTs as SVG.
    labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
    aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")

    for fsa, stem in ((L, "L"), (L_disambig, "L_disambig")):
        fsa.labels_sym = labels_sym
        fsa.aux_labels_sym = aux_labels_sym
        fsa.draw(str(lang_dir / f"{stem}.svg"), title=f"{stem}.pt")


if __name__ == "__main__":
    main()
|
266
egs/ami/SURT/local/prepare_lang_bpe.py
Executable file
266
egs/ami/SURT/local/prepare_lang_bpe.py
Executable file
@ -0,0 +1,266 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
|
||||
This script takes as input `lang_dir`, which should contain::
|
||||
|
||||
- lang_dir/bpe.model,
|
||||
- lang_dir/words.txt
|
||||
|
||||
and generates the following files in the directory `lang_dir`:
|
||||
|
||||
- lexicon.txt
|
||||
- lexicon_disambig.txt
|
||||
- L.pt
|
||||
- L_disambig.pt
|
||||
- tokens.txt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from prepare_lang import (
|
||||
Lexicon,
|
||||
add_disambig_symbols,
|
||||
add_self_loops,
|
||||
write_lexicon,
|
||||
write_mapping,
|
||||
)
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def lexicon_to_fst_no_sil(
    lexicon: Lexicon,
    token2id: Dict[str, int],
    word2id: Dict[str, int],
    need_self_loops: bool = False,
) -> k2.Fsa:
    """Build the lexicon FST (in k2 format) without any silence arcs.

    Args:
      lexicon:
        The input lexicon. See also :func:`read_lexicon`
      token2id:
        A dict mapping tokens to IDs.
      word2id:
        A dict mapping words to IDs.
      need_self_loops:
        If True, add self-loop to states with non-epsilon output symbols
        on at least one arc out of the state. The input label for this
        self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
    Returns:
      Return an instance of `k2.Fsa` representing the given lexicon.
    """
    loop_state = 0  # words enter and leave from here
    next_state = 1  # the next un-allocated state, will be incremented as we go

    # The blank symbol <blk> is defined in local/train_bpe_model.py
    assert token2id["<blk>"] == 0
    assert word2id["<eps>"] == 0
    eps = 0

    arcs = []
    for word, pieces in lexicon:
        assert len(pieces) > 0, f"{word} has no pronunciations"

        word_id = word2id[word]
        piece_ids = [token2id[piece] for piece in pieces]
        *leading_ids, last_id = piece_ids

        # The word label is emitted on the first arc of the word; every
        # later arc of the same word outputs epsilon.
        state = loop_state
        out_label = word_id
        for piece_id in leading_ids:
            arcs.append([state, next_state, piece_id, out_label, 0])
            state = next_state
            next_state += 1
            out_label = eps

        # The last piece of the word leads back to the loop state.
        arcs.append([state, loop_state, last_id, out_label, 0])

    if need_self_loops:
        arcs = add_self_loops(
            arcs,
            disambig_token=token2id["#0"],
            disambig_word=word2id["#0"],
        )

    final_state = next_state
    arcs.append([loop_state, final_state, -1, -1, 0])
    arcs.append([final_state])

    # Serialize the arcs ordered by source state, one arc per line.
    fsa_text = "\n".join(
        " ".join(str(field) for field in arc)
        for arc in sorted(arcs, key=lambda arc: arc[0])
    )
    return k2.Fsa.from_str(fsa_text, acceptor=False)
|
||||
|
||||
|
||||
def generate_lexicon(
    model_file: str, words: List[str], oov: str
) -> Tuple[Lexicon, Dict[str, int]]:
    """Spell every word with the pieces of a BPE model.

    Args:
      model_file:
        Path to a sentencepiece model.
      words:
        A list of strings representing words.
      oov:
        The out of vocabulary word in lexicon.
    Returns:
      Return a tuple with two elements:
        - The lexicon: one (word, pieces) entry per input word, in input
          order, followed by a final entry for `oov`.
        - A dict representing the token symbol, mapping from tokens to IDs.
    """
    sp = spm.SentencePieceProcessor()
    sp.load(str(model_file))

    # Encode to word piece IDs first and map them back to strings
    # afterwards, instead of encoding straight to word piece strings,
    # to avoid OOV tokens.
    words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int)
    lexicon = [
        (word, sp.id_to_piece(piece_ids))
        for word, piece_ids in zip(words, words_pieces_ids)
    ]

    # The OOV word is spelled as "▁" followed by the model's unknown piece.
    lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())]))

    token2id: Dict[str, int] = {}
    for piece_id in range(sp.vocab_size()):
        token2id[sp.id_to_piece(piece_id)] = piece_id

    return lexicon, token2id
|
||||
|
||||
|
||||
def get_args():
    """Parse the command-line options of this script.

    Returns:
      The namespace produced by ``argparse``; it carries ``lang_dir``,
      ``oov`` and ``debug``.
    """
    lang_dir_help = (
        "Input and output directory.\n"
        "It should contain the bpe.model and words.txt\n"
    )
    debug_help = (
        "True for debugging, which will generate\n"
        "a visualization of the lexicon FST.\n"
        "\n"
        "Caution: If your lexicon contains hundreds of thousands\n"
        "of lines, please set it to False!\n"
        "\n"
        'See "test/test_bpe_lexicon.py" for usage.\n'
    )

    parser = argparse.ArgumentParser()
    parser.add_argument("--lang-dir", type=str, help=lang_dir_help)
    parser.add_argument(
        "--oov",
        type=str,
        default="<UNK>",
        help="The out of vocabulary word in lexicon.",
    )
    parser.add_argument("--debug", type=str2bool, default=False, help=debug_help)
    return parser.parse_args()
|
||||
|
||||
|
||||
def main():
    """Generate tokens.txt, lexicon.txt, lexicon_disambig.txt, L.pt and
    L_disambig.pt inside `--lang-dir`, starting from its bpe.model and
    words.txt.
    """
    args = get_args()
    lang_dir = Path(args.lang_dir)
    model_file = lang_dir / "bpe.model"

    word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
    words = word_sym_table.symbols

    # Special symbols are not spelled with BPE pieces; the OOV word gets
    # its own entry inside generate_lexicon().
    excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", args.oov, "#0", "<s>", "</s>"]
    for special in excluded:
        if special in words:
            words.remove(special)

    lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov)

    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)

    # Give #0 .. #max_disambig the IDs right after the largest piece ID.
    first_free_id = max(token_sym_table.values()) + 1
    for offset in range(max_disambig + 1):
        disambig = f"#{offset}"
        assert disambig not in token_sym_table
        token_sym_table[disambig] = first_free_id + offset

    for special in ("#0", "<s>", "</s>"):
        word_sym_table.add(special)

    write_mapping(lang_dir / "tokens.txt", token_sym_table)

    write_lexicon(lang_dir / "lexicon.txt", lexicon)
    write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)

    fst_options = dict(token2id=token_sym_table, word2id=word_sym_table)
    L = lexicon_to_fst_no_sil(lexicon, **fst_options)
    L_disambig = lexicon_to_fst_no_sil(
        lexicon_disambig, need_self_loops=True, **fst_options
    )
    torch.save(L.as_dict(), lang_dir / "L.pt")
    torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")

    if not args.debug:
        return

    # Debug mode: attach the symbol tables and render both FSTs as SVG.
    labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
    aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")

    for fsa, stem in ((L, "L"), (L_disambig, "L_disambig")):
        fsa.labels_sym = labels_sym
        fsa.aux_labels_sym = aux_labels_sym
        fsa.draw(str(lang_dir / f"{stem}.svg"), title=f"{stem}.pt")


if __name__ == "__main__":
    main()
|
100
egs/ami/SURT/local/train_bpe_model.py
Executable file
100
egs/ami/SURT/local/train_bpe_model.py
Executable file
@ -0,0 +1,100 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# You can install sentencepiece via:
|
||||
#
|
||||
# pip install sentencepiece
|
||||
#
|
||||
# Due to an issue reported in
|
||||
# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
|
||||
#
|
||||
# Please install a version >=0.1.96
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
|
||||
def get_args():
    """Parse the command-line options of this script.

    Returns:
      The namespace produced by ``argparse``; it carries ``lang_dir``,
      ``transcript`` and ``vocab_size``.
    """
    lang_dir_help = (
        "Input and output directory.\n"
        "The generated bpe.model is saved to this directory.\n"
    )

    parser = argparse.ArgumentParser()
    parser.add_argument("--lang-dir", type=str, help=lang_dir_help)
    parser.add_argument("--transcript", type=str, help="Training transcript.")
    parser.add_argument(
        "--vocab-size", type=int, help="Vocabulary size for BPE training"
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
def main():
    """Train a unigram sentencepiece model and publish it as bpe.model.

    The model is trained to `<lang_dir>/unigram_<vocab_size>.model` and then
    copied to `<lang_dir>/bpe.model`. If the trained model file already
    exists, nothing is trained or copied.
    """
    args = get_args()
    vocab_size = args.vocab_size
    lang_dir = Path(args.lang_dir)

    model_type = "unigram"
    model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
    model_file = Path(model_prefix + ".model")

    if model_file.is_file():
        print(f"{model_file} exists - skipping")
        return

    user_defined_symbols = ["<blk>", "<sos/eos>"]
    # Note: unk_id is fixed to 2, i.e. it directly follows the two
    # user-defined symbols.
    # If you change it, you should also change other
    # places that are using it.
    unk_id = len(user_defined_symbols)

    spm.SentencePieceTrainer.train(
        input=args.transcript,
        vocab_size=vocab_size,
        model_type=model_type,
        model_prefix=model_prefix,
        input_sentence_size=100000000,
        character_coverage=1.0,
        user_defined_symbols=user_defined_symbols,
        unk_id=unk_id,
        bos_id=-1,
        eos_id=-1,
    )

    shutil.copyfile(model_file, f"{lang_dir}/bpe.model")


if __name__ == "__main__":
    main()
|
195
egs/ami/SURT/prepare.sh
Executable file
195
egs/ami/SURT/prepare.sh
Executable file
@ -0,0 +1,195 @@
|
||||
#!/usr/bin/env bash

# Data preparation for the AMI SURT recipe. The work is split into numbered
# stages; only stages in the range [stage, stop_stage] are run.

set -eou pipefail

stage=-1
stop_stage=100

# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
#  - $dl_dir/amicorpus
#      You can find audio and transcripts for AMI in this path.
#
#  - $dl_dir/icsi
#      You can find audio and transcripts for ICSI in this path.
#
#  - $dl_dir/rirs_noises
#      This directory contains the RIRS_NOISES corpus downloaded from https://openslr.org/28/.
#
dl_dir=$PWD/download

. shared/parse_options.sh || exit 1

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
vocab_size=500

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "dl_dir: $dl_dir"

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Download data"

  # If you have pre-downloaded it to /path/to/amicorpus,
  # you can create a symlink
  #
  #   ln -sfv /path/to/amicorpus $dl_dir/amicorpus
  #
  if [ ! -d $dl_dir/amicorpus ]; then
    for mic in ihm ihm-mix sdm mdm8-bf; do
      lhotse download ami --mic $mic $dl_dir/amicorpus
    done
  fi

  # If you have pre-downloaded it to /path/to/icsi,
  # you can create a symlink
  #
  #   ln -sfv /path/to/icsi $dl_dir/icsi
  #
  if [ ! -d $dl_dir/icsi ]; then
    lhotse download icsi $dl_dir/icsi
  fi

  # If you have pre-downloaded it to /path/to/rirs_noises,
  # you can create a symlink
  #
  #   ln -sfv /path/to/rirs_noises $dl_dir/
  #
  if [ ! -d $dl_dir/rirs_noises ]; then
    lhotse download rirs_noises $dl_dir
  fi
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Prepare AMI manifests"
  # We assume that you have downloaded the AMI corpus
  # to $dl_dir/amicorpus. We perform text normalization for the transcripts.
  mkdir -p data/manifests
  for mic in ihm ihm-mix sdm mdm8-bf; do
    log "Preparing AMI manifest for $mic"
    lhotse prepare ami --mic $mic --max-words-per-segment 30 --merge-consecutive $dl_dir/amicorpus data/manifests/
  done
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Prepare ICSI manifests"
  # We assume that you have downloaded the ICSI corpus
  # to $dl_dir/icsi. We perform text normalization for the transcripts.
  mkdir -p data/manifests
  log "Preparing ICSI manifest"
  for mic in ihm ihm-mix sdm; do
    lhotse prepare icsi --mic $mic $dl_dir/icsi data/manifests/
  done
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Prepare RIRs"
  # We assume that you have downloaded the RIRS_NOISES corpus
  # to $dl_dir/rirs_noises
  lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises data/manifests
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Extract features for AMI and ICSI recordings"
  python local/compute_fbank_ami.py
  python local/compute_fbank_icsi.py
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Create sources for simulating mixtures"
  # In the following script, we speed-perturb the IHM recordings and extract features.
  python local/compute_fbank_ihm.py
  # Pool the AMI and ICSI IHM training cuts, trim them to their word
  # alignments, keep cuts of at most 12 seconds, and shuffle the result.
  lhotse combine data/manifests/ami-ihm_cuts_train.jsonl.gz \
    data/manifests/icsi-ihm_cuts_train.jsonl.gz - |\
    lhotse cut trim-to-alignments --type word --max-pause 0.5 - - |\
    lhotse filter 'duration<=12.0' - - |\
    shuf | gzip -c > data/manifests/ihm_cuts_train_trimmed.jsonl.gz
fi

if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  log "Stage 6: Create training mixtures"
  lhotse workflows simulate-meetings \
    --method conversational \
    --same-spk-pause 0.5 \
    --diff-spk-pause 0.5 \
    --diff-spk-overlap 1.0 \
    --prob-diff-spk-overlap 0.8 \
    --num-meetings 200000 \
    --num-speakers-per-meeting 2,3 \
    --max-duration-per-speaker 15.0 \
    --max-utterances-per-speaker 3 \
    --seed 1234 \
    --num-jobs 2 \
    data/manifests/ihm_cuts_train_trimmed.jsonl.gz \
    data/manifests/ai-mix_cuts_clean.jsonl.gz

  python local/compute_fbank_aimix.py

  # Add source features to the manifest (will be used for masking loss)
  # This may take ~2 hours.
  python local/add_source_feats.py

  # Combine clean and reverb
  cat <(gunzip -c data/manifests/cuts_train_clean_sources.jsonl.gz) \
    <(gunzip -c data/manifests/cuts_train_reverb_sources.jsonl.gz) |\
    shuf | gzip -c > data/manifests/cuts_train_comb_sources.jsonl.gz
fi

if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
  log "Stage 7: Create training mixtures from real sessions"
  python local/prepare_ami_train_cuts.py
  python local/prepare_icsi_train_cuts.py

  # Combine AMI and ICSI
  cat <(gunzip -c data/manifests/cuts_train_ami.jsonl.gz) \
    <(gunzip -c data/manifests/cuts_train_icsi.jsonl.gz) |\
    shuf | gzip -c > data/manifests/cuts_train_ami_icsi.jsonl.gz
fi

if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
  log "Stage 8: Dump transcripts for BPE model training (using AMI and ICSI)."
  mkdir -p data/lm
  # Extract the "text" field of every supervision and strip the double quotes.
  cat <(gunzip -c data/manifests/ami-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g') \
    <(gunzip -c data/manifests/icsi-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g') \
    > data/lm/transcript_words.txt
fi

if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
  log "Stage 9: Prepare BPE based lang (combining AMI and ICSI)"

  lang_dir=data/lang_bpe_${vocab_size}
  mkdir -p $lang_dir

  # Add special words to words.txt
  echo "<eps> 0" > $lang_dir/words.txt
  echo "!SIL 1" >> $lang_dir/words.txt
  echo "<UNK> 2" >> $lang_dir/words.txt

  # Add regular words to words.txt
  # (IDs start at 3, right after the three special words above.)
  cat data/lm/transcript_words.txt | grep -o -E '\w+' | sort -u | awk '{print $0,NR+2}' >> $lang_dir/words.txt

  # Add remaining special word symbols expected by LM scripts.
  # Each one gets the current line count of words.txt as its ID.
  num_words=$(cat $lang_dir/words.txt | wc -l)
  echo "<s> ${num_words}" >> $lang_dir/words.txt
  num_words=$(cat $lang_dir/words.txt | wc -l)
  echo "</s> ${num_words}" >> $lang_dir/words.txt
  num_words=$(cat $lang_dir/words.txt | wc -l)
  echo "#0 ${num_words}" >> $lang_dir/words.txt

  ./local/train_bpe_model.py \
    --lang-dir $lang_dir \
    --vocab-size $vocab_size \
    --transcript data/lm/transcript_words.txt

  if [ ! -f $lang_dir/L_disambig.pt ]; then
    ./local/prepare_lang_bpe.py --lang-dir $lang_dir
  fi
fi
|
1
egs/ami/SURT/shared
Symbolic link
1
egs/ami/SURT/shared
Symbolic link
@ -0,0 +1 @@
|
||||
../../../icefall/shared
|
Loading…
x
Reference in New Issue
Block a user