initial commit for SURT AMI recipe

This commit is contained in:
Desh Raj 2023-06-15 14:34:43 -04:00
parent d6b88aaa98
commit 14818f5dd8
12 changed files with 1741 additions and 0 deletions

View File

@ -0,0 +1,78 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file adds source features as temporal arrays to the mixture manifests.
It looks for manifests in the directory data/manifests.
"""
import logging
from pathlib import Path
import numpy as np
from lhotse import CutSet, LilcomChunkyWriter, load_manifest, load_manifest_lazy
from tqdm import tqdm
def add_source_feats():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
logging.info("Reading mixed cuts")
mixed_cuts_clean = load_manifest_lazy(src_dir / "cuts_train_clean.jsonl.gz")
mixed_cuts_reverb = load_manifest_lazy(src_dir / "cuts_train_reverb.jsonl.gz")
logging.info("Reading source cuts")
source_cuts = load_manifest(src_dir / "ihm_cuts_train_trimmed.jsonl.gz")
logging.info("Adding source features to the mixed cuts")
pbar = tqdm(total=len(mixed_cuts_clean), desc="Adding source features")
with CutSet.open_writer(
src_dir / "cuts_train_clean_sources.jsonl.gz"
) as cut_writer_clean, CutSet.open_writer(
src_dir / "cuts_train_reverb_sources.jsonl.gz"
) as cut_writer_reverb, LilcomChunkyWriter(
output_dir / "feats_train_clean_sources"
) as source_feat_writer:
for cut_clean, cut_reverb in zip(mixed_cuts_clean, mixed_cuts_reverb):
assert cut_reverb.id == cut_clean.id + "_rvb"
source_feats = []
source_feat_offsets = []
cur_offset = 0
for sup in sorted(
cut_clean.supervisions, key=lambda s: (s.start, s.speaker)
):
source_cut = source_cuts[sup.id]
source_feats.append(source_cut.load_features())
source_feat_offsets.append(cur_offset)
cur_offset += source_cut.num_frames
cut_clean.source_feats = source_feat_writer.store_array(
cut_clean.id, np.concatenate(source_feats, axis=0)
)
cut_clean.source_feat_offsets = source_feat_offsets
cut_writer_clean.write(cut_clean)
# Also write the reverb cut
cut_reverb.source_feats = cut_clean.source_feats
cut_reverb.source_feat_offsets = cut_clean.source_feat_offsets
cut_writer_reverb.write(cut_reverb)
pbar.update(1)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
add_source_feats()

View File

@ -0,0 +1,185 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the synthetically mixed AMI and ICSI
train set.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
import random
import warnings
from pathlib import Path
import torch
import torch.multiprocessing
import torchaudio
from lhotse import (
AudioSource,
LilcomChunkyWriter,
Recording,
load_manifest,
load_manifest_lazy,
)
from lhotse.audio import set_ffmpeg_torchaudio_info_enabled
from lhotse.cut import MixedCut, MixTrack, MultiCut
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
from lhotse.utils import fix_random_seed, uuid4
from tqdm import tqdm
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
torchaudio.set_audio_backend("soundfile")
set_ffmpeg_torchaudio_info_enabled(False)
def compute_fbank_aimix():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
sampling_rate = 16000
num_mel_bins = 80
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
logging.info("Reading manifests")
train_cuts = load_manifest_lazy(src_dir / "ai-mix_cuts_clean_full.jsonl.gz")
# only uses RIRs and noises from REVERB challenge
real_rirs = load_manifest(src_dir / "real-rir_recordings_all.jsonl.gz").filter(
lambda r: "RVB2014" in r.id
)
noises = load_manifest(src_dir / "iso-noise_recordings_all.jsonl.gz").filter(
lambda r: "RVB2014" in r.id
)
# Apply perturbation to the training cuts
logging.info("Applying perturbation to the training cuts")
train_cuts_rvb = train_cuts.map(
lambda c: augment(
c, perturb_snr=True, rirs=real_rirs, noises=noises, perturb_loudness=True
)
)
logging.info("Extracting fbank features for training cuts")
_ = train_cuts.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / "ai-mix_feats_clean",
manifest_path=src_dir / "cuts_train_clean.jsonl.gz",
batch_duration=5000,
num_workers=4,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
_ = train_cuts_rvb.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / "ai-mix_feats_reverb",
manifest_path=src_dir / "cuts_train_reverb.jsonl.gz",
batch_duration=5000,
num_workers=4,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
def augment(cut, perturb_snr=False, rirs=None, noises=None, perturb_loudness=False):
"""
Given a mixed cut, this function optionally applies the following augmentations:
- Perturbing the SNRs of the tracks (in range [-5, 5] dB)
- Reverberation using a randomly selected RIR
- Adding noise
- Perturbing the loudness (in range [-20, -25] dB)
"""
out_cut = cut.drop_features()
# Perturb the SNRs (optional)
if perturb_snr:
snrs = [random.uniform(-5, 5) for _ in range(len(cut.tracks))]
for i, (track, snr) in enumerate(zip(out_cut.tracks, snrs)):
if i == 0:
# Skip the first track since it is the reference
continue
track.snr = snr
# Reverberate the cut (optional)
if rirs is not None:
# Select an RIR at random
rir = random.choice(rirs)
# Select a channel at random
rir_channel = random.choice(list(range(rir.num_channels)))
# Reverberate the cut
out_cut = out_cut.reverb_rir(rir_recording=rir, rir_channels=[rir_channel])
# Add noise (optional)
if noises is not None:
# Select a noise recording at random
noise = random.choice(noises).to_cut()
if isinstance(noise, MultiCut):
noise = noise.to_mono()[0]
# Select an SNR at random
snr = random.uniform(10, 30)
# Repeat the noise to match the duration of the cut
noise = repeat_cut(noise, out_cut.duration)
out_cut = MixedCut(
id=out_cut.id,
tracks=[
MixTrack(cut=out_cut, type="MixedCut"),
MixTrack(cut=noise, type="DataCut", snr=snr),
],
)
# Perturb the loudness (optional)
if perturb_loudness:
target_loudness = random.uniform(-20, -25)
out_cut = out_cut.normalize_loudness(target_loudness, mix_first=True)
return out_cut
def repeat_cut(cut, duration):
while cut.duration < duration:
cut = cut.mix(cut, offset_other_by=cut.duration)
return cut.truncate(duration=duration)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
fix_random_seed(42)
compute_fbank_aimix()

View File

@ -0,0 +1,94 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the AMI dataset.
We compute features for full recordings (i.e., without trimming to supervisions).
This way we can create arbitrary segmentations later.
The generated fbank features are saved in data/fbank.
"""
import logging
import math
from pathlib import Path
import torch
import torch.multiprocessing
from lhotse import CutSet, LilcomChunkyWriter
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
def compute_fbank_ami():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
sampling_rate = 16000
num_mel_bins = 80
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
logging.info("Reading manifests")
manifests = {}
for part in ["ihm-mix", "sdm", "mdm8-bf"]:
manifests[part] = read_manifests_if_cached(
dataset_parts=["train", "dev", "test"],
output_dir=src_dir,
prefix=f"ami-{part}",
suffix="jsonl.gz",
)
for part in ["ihm-mix", "sdm", "mdm8-bf"]:
for split in ["train", "dev", "test"]:
logging.info(f"Processing {part} {split}")
cuts = CutSet.from_manifests(
**manifests[part][split]
).compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"ami-{part}_{split}_feats",
manifest_path=src_dir / f"cuts_ami-{part}_{split}.jsonl.gz",
batch_duration=5000,
num_workers=4,
storage_type=LilcomChunkyWriter,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_ami()

View File

@ -0,0 +1,95 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the ICSI dataset.
We compute features for full recordings (i.e., without trimming to supervisions).
This way we can create arbitrary segmentations later.
The generated fbank features are saved in data/fbank.
"""
import logging
import math
from pathlib import Path
import torch
import torch.multiprocessing
from lhotse import CutSet, LilcomChunkyWriter
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
def compute_fbank_icsi():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
sampling_rate = 16000
num_mel_bins = 80
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
logging.info("Reading manifests")
manifests = {}
for part in ["ihm-mix", "sdm"]:
manifests[part] = read_manifests_if_cached(
dataset_parts=["train"],
output_dir=src_dir,
prefix=f"icsi-{part}",
suffix="jsonl.gz",
)
for part in ["ihm-mix", "sdm"]:
for split in ["train"]:
logging.info(f"Processing {part} {split}")
cuts = CutSet.from_manifests(
**manifests[part][split]
).compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"icsi-{part}_{split}_feats",
manifest_path=src_dir / f"cuts_icsi-{part}_{split}.jsonl.gz",
batch_duration=5000,
num_workers=4,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_icsi()

View File

@ -0,0 +1,101 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the trimmed sub-segments which will be
used for simulating the training mixtures.
The generated fbank features are saved in data/fbank.
"""
import logging
import math
from pathlib import Path
import torch
import torch.multiprocessing
import torchaudio
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
from lhotse.audio import set_ffmpeg_torchaudio_info_enabled
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached
from tqdm import tqdm
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
torchaudio.set_audio_backend("soundfile")
set_ffmpeg_torchaudio_info_enabled(False)
def compute_fbank_ihm():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
sampling_rate = 16000
num_mel_bins = 80
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
logging.info("Reading manifests")
manifests = {}
for data in ["ami", "icsi"]:
manifests[data] = read_manifests_if_cached(
dataset_parts=["train"],
output_dir=src_dir,
types=["recordings", "supervisions"],
prefix=f"{data}-ihm",
suffix="jsonl.gz",
)
logging.info("Computing features")
for data in ["ami", "icsi"]:
cs = CutSet.from_manifests(**manifests[data]["train"])
cs = cs.trim_to_supervisions(keep_overlapping=False)
cs = cs.normalize_loudness(target=-23.0, affix_id=False)
cs = cs + cs.perturb_speed(0.9) + cs.perturb_speed(1.1)
_ = cs.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"{data}-ihm_train_feats",
manifest_path=src_dir / f"{data}-ihm_cuts_train.jsonl.gz",
batch_duration=5000,
num_workers=4,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_ihm()

View File

@ -0,0 +1,146 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file creates AMI train segments.
"""
import logging
import math
from pathlib import Path
import torch
import torch.multiprocessing
from lhotse import LilcomChunkyWriter, load_manifest_lazy
from lhotse.cut import Cut, CutSet
from lhotse.utils import EPSILON, add_durations
from tqdm import tqdm
def cut_into_windows(cuts: CutSet, duration: float):
"""
This function takes a CutSet and cuts each cut into windows of roughly
`duration` seconds. By roughly, we mean that we try to adjust for the last supervision
that exceeds the duration, or is shorter than the duration.
"""
res = []
with tqdm() as pbar:
for cut in cuts:
pbar.update(1)
sups = cut.index_supervisions()[cut.id]
sr = cut.sampling_rate
start = 0.0
end = duration
num_tries = 0
while start < cut.duration and num_tries < 2:
# Find the supervision that are cut by the window endpoint
hitlist = [iv for iv in sups.at(end) if iv.begin < end]
# If there are no supervisions, we are done
if not hitlist:
res.append(
cut.truncate(
offset=start,
duration=add_durations(end, -start, sampling_rate=sr),
keep_excessive_supervisions=False,
)
)
# Update the start and end for the next window
start = end
end = add_durations(end, duration, sampling_rate=sr)
else:
# find ratio of durations cut by the window endpoint
ratios = [
add_durations(end, -iv.end, sampling_rate=sr) / iv.length()
for iv in hitlist
]
# we retain the supervisions that have >50% of their duration
# in the window, and discard the others
retained = []
discarded = []
for iv, ratio in zip(hitlist, ratios):
if ratio > 0.5:
retained.append(iv)
else:
discarded.append(iv)
cur_end = max(iv.end for iv in retained) if retained else end
res.append(
cut.truncate(
offset=start,
duration=add_durations(cur_end, -start, sampling_rate=sr),
keep_excessive_supervisions=False,
)
)
# For the next window, we start at the earliest discarded supervision
next_start = min(iv.begin for iv in discarded) if discarded else end
next_end = add_durations(next_start, duration, sampling_rate=sr)
# It may happen that next_start is the same as start, in which case
# we will advance the window anyway
if next_start == start:
logging.warning(
f"Next start is the same as start: {next_start} == {start} for cut {cut.id}"
)
start = end + EPSILON
end = add_durations(start, duration, sampling_rate=sr)
num_tries += 1
else:
start = next_start
end = next_end
return CutSet.from_cuts(res)
def prepare_train_cuts():
src_dir = Path("data/manifests")
logging.info("Loading the manifests")
train_cuts_ihm = load_manifest_lazy(
src_dir / "cuts_ami-ihm-mix_train.jsonl.gz"
).map(lambda c: c.with_id(f"{c.id}_ihm-mix"))
train_cuts_sdm = load_manifest_lazy(src_dir / "cuts_ami-sdm_train.jsonl.gz").map(
lambda c: c.with_id(f"{c.id}_sdm")
)
train_cuts_mdm = load_manifest_lazy(
src_dir / "cuts_ami-mdm8-bf_train.jsonl.gz"
).map(lambda c: c.with_id(f"{c.id}_mdm8-bf"))
# Combine all cuts into one CutSet
train_cuts = train_cuts_ihm + train_cuts_sdm + train_cuts_mdm
train_cuts_1 = train_cuts.trim_to_supervision_groups(max_pause=0.5)
train_cuts_2 = train_cuts.trim_to_supervision_groups(max_pause=0.0)
# Combine the two segmentations
train_all = train_cuts_1 + train_cuts_2
# At this point, some of the cuts may be very long. We will cut them into windows of
# roughly 30 seconds.
logging.info("Cutting the segments into windows of 30 seconds")
train_all_30 = cut_into_windows(train_all, duration=30.0)
logging.info(f"Number of cuts after cutting into windows: {len(train_all_30)}")
# Show statistics
train_all.describe(full=True)
# Save the cuts
logging.info("Saving the cuts")
train_all.to_file(src_dir / "cuts_train_ami.jsonl.gz")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
prepare_train_cuts()

View File

@ -0,0 +1,67 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file creates ICSI train segments.
"""
import logging
from pathlib import Path
from lhotse import load_manifest_lazy
from prepare_ami_train_cuts import cut_into_windows
def prepare_train_cuts():
src_dir = Path("data/manifests")
logging.info("Loading the manifests")
train_cuts_ihm = load_manifest_lazy(
src_dir / "cuts_icsi-ihm-mix_train.jsonl.gz"
).map(lambda c: c.with_id(f"{c.id}_ihm-mix"))
train_cuts_sdm = load_manifest_lazy(src_dir / "cuts_icsi-sdm_train.jsonl.gz").map(
lambda c: c.with_id(f"{c.id}_sdm")
)
# Combine all cuts into one CutSet
train_cuts = train_cuts_ihm + train_cuts_sdm
train_cuts_1 = train_cuts.trim_to_supervision_groups(max_pause=0.5)
train_cuts_2 = train_cuts.trim_to_supervision_groups(max_pause=0.0)
# Combine the two segmentations
train_all = train_cuts_1 + train_cuts_2
# At this point, some of the cuts may be very long. We will cut them into windows of
# roughly 30 seconds.
logging.info("Cutting the segments into windows of 30 seconds")
train_all_30 = cut_into_windows(train_all, duration=30.0)
logging.info(f"Number of cuts after cutting into windows: {len(train_all_30)}")
# Show statistics
train_all.describe(full=True)
# Save the cuts
logging.info("Saving the cuts")
train_all.to_file(src_dir / "cuts_train_icsi.jsonl.gz")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
prepare_train_cuts()

View File

@ -0,0 +1,413 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
consisting of words and tokens (i.e., phones) and does the following:
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
2. Generate tokens.txt, the token table mapping a token to a unique integer.
3. Generate words.txt, the word table mapping a word to a unique integer.
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.
"""
import argparse
import math
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
import k2
import torch
from icefall.lexicon import read_lexicon, write_lexicon
from icefall.utils import str2bool
Lexicon = List[Tuple[str, List[str]]]
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain a file lexicon.txt.
Generated files by this script are saved into this directory.
""",
)
parser.add_argument(
"--debug",
type=str2bool,
default=False,
help="""True for debugging, which will generate
a visualization of the lexicon FST.
Caution: If your lexicon contains hundreds of thousands
of lines, please set it to False!
""",
)
return parser.parse_args()
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
Note:
No need to implement `read_mapping` as it can be done
through :func:`k2.SymbolTable.from_file`.
Args:
filename:
Filename to save the mapping.
sym2id:
A dict mapping symbols to IDs.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf-8") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")
def get_tokens(lexicon: Lexicon) -> List[str]:
"""Get tokens from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique tokens.
"""
ans = set()
for _, tokens in lexicon:
ans.update(tokens)
sorted_ans = sorted(list(ans))
return sorted_ans
def get_words(lexicon: Lexicon) -> List[str]:
"""Get words from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique words.
"""
ans = set()
for word, _ in lexicon:
ans.add(word)
sorted_ans = sorted(list(ans))
return sorted_ans
def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
"""It adds pseudo-token disambiguation symbols #1, #2 and so on
at the ends of tokens to ensure that all pronunciations are different,
and that none is a prefix of another.
See also add_lex_disambig.pl from kaldi.
Args:
lexicon:
It is returned by :func:`read_lexicon`.
Returns:
Return a tuple with two elements:
- The output lexicon with disambiguation symbols
- The ID of the max disambiguation symbol that appears
in the lexicon
"""
# (1) Work out the count of each token-sequence in the
# lexicon.
count = defaultdict(int)
for _, tokens in lexicon:
count[" ".join(tokens)] += 1
# (2) For each left sub-sequence of each token-sequence, note down
# that it exists (for identifying prefixes of longer strings).
issubseq = defaultdict(int)
for _, tokens in lexicon:
tokens = tokens.copy()
tokens.pop()
while tokens:
issubseq[" ".join(tokens)] = 1
tokens.pop()
# (3) For each entry in the lexicon:
# if the token sequence is unique and is not a
# prefix of another word, no disambig symbol.
# Else output #1, or #2, #3, ... if the same token-seq
# has already been assigned a disambig symbol.
ans = []
# We start with #1 since #0 has its own purpose
first_allowed_disambig = 1
max_disambig = first_allowed_disambig - 1
last_used_disambig_symbol_of = defaultdict(int)
for word, tokens in lexicon:
tokenseq = " ".join(tokens)
assert tokenseq != ""
if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
ans.append((word, tokens))
continue
cur_disambig = last_used_disambig_symbol_of[tokenseq]
if cur_disambig == 0:
cur_disambig = first_allowed_disambig
else:
cur_disambig += 1
if cur_disambig > max_disambig:
max_disambig = cur_disambig
last_used_disambig_symbol_of[tokenseq] = cur_disambig
tokenseq += f" #{cur_disambig}"
ans.append((word, tokenseq.split()))
return ans, max_disambig
def generate_id_map(symbols: List[str]) -> Dict[str, int]:
"""Generate ID maps, i.e., map a symbol to a unique ID.
Args:
symbols:
A list of unique symbols.
Returns:
A dict containing the mapping between symbols and IDs.
"""
return {sym: i for i, sym in enumerate(symbols)}
def add_self_loops(
arcs: List[List[Any]], disambig_token: int, disambig_word: int
) -> List[List[Any]]:
"""Adds self-loops to states of an FST to propagate disambiguation symbols
through it. They are added on each state with non-epsilon output symbols
on at least one arc out of the state.
See also fstaddselfloops.pl from Kaldi. One difference is that
Kaldi uses OpenFst style FSTs and it has multiple final states.
This function uses k2 style FSTs and it does not need to add self-loops
to the final state.
The input label of a self-loop is `disambig_token`, while the output
label is `disambig_word`.
Args:
arcs:
A list-of-list. The sublist contains
`[src_state, dest_state, label, aux_label, score]`
disambig_token:
It is the token ID of the symbol `#0`.
disambig_word:
It is the word ID of the symbol `#0`.
Return:
Return new `arcs` containing self-loops.
"""
states_needs_self_loops = set()
for arc in arcs:
src, dst, ilabel, olabel, score = arc
if olabel != 0:
states_needs_self_loops.add(src)
ans = []
for s in states_needs_self_loops:
ans.append([s, s, disambig_token, disambig_word, 0])
return arcs + ans
def lexicon_to_fst(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
sil_token: str = "SIL",
sil_prob: float = 0.5,
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format) with optional silence at
the beginning and end of each word.
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
sil_token:
The silence token.
sil_prob:
The probability for adding a silence at the beginning and end
of the word.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
assert sil_prob > 0.0 and sil_prob < 1.0
# CAUTION: we use score, i.e, negative cost.
sil_score = math.log(sil_prob)
no_sil_score = math.log(1.0 - sil_prob)
start_state = 0
loop_state = 1 # words enter and leave from here
sil_state = 2 # words terminate here when followed by silence; this state
# has a silence transition to loop_state.
next_state = 3 # the next un-allocated state, will be incremented as we go.
arcs = []
assert token2id["<eps>"] == 0
assert word2id["<eps>"] == 0
eps = 0
sil_token = token2id[sil_token]
arcs.append([start_state, loop_state, eps, eps, no_sil_score])
arcs.append([start_state, sil_state, eps, eps, sil_score])
arcs.append([sil_state, loop_state, sil_token, eps, 0])
for word, tokens in lexicon:
assert len(tokens) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
tokens = [token2id[i] for i in tokens]
for i in range(len(tokens) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, tokens[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last token of this word
# It has two out-going arcs, one to the loop state,
# the other one to the sil_state.
i = len(tokens) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
lexicon_filename = lang_dir / "lexicon.txt"
sil_token = "SIL"
sil_prob = 0.5
lexicon = read_lexicon(lexicon_filename)
tokens = get_tokens(lexicon)
words = get_words(lexicon)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in tokens
tokens.append(f"#{i}")
assert "<eps>" not in tokens
tokens = ["<eps>"] + tokens
assert "<eps>" not in words
assert "#0" not in words
assert "<s>" not in words
assert "</s>" not in words
words = ["<eps>"] + words + ["#0", "<s>", "</s>"]
token2id = generate_id_map(tokens)
word2id = generate_id_map(words)
write_mapping(lang_dir / "tokens.txt", token2id)
write_mapping(lang_dir / "words.txt", word2id)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst(
lexicon,
token2id=token2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
)
L_disambig = lexicon_to_fst(
lexicon_disambig,
token2id=token2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
need_self_loops=True,
)
torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if args.debug:
labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
L.labels_sym = labels_sym
L.aux_labels_sym = aux_labels_sym
L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
L_disambig.labels_sym = labels_sym
L_disambig.aux_labels_sym = aux_labels_sym
L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,266 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
"""
This script takes as input `lang_dir`, which should contain::
- lang_dir/bpe.model,
- lang_dir/words.txt
and generates the following files in the directory `lang_dir`:
- lexicon.txt
- lexicon_disambig.txt
- L.pt
- L_disambig.pt
- tokens.txt
"""
import argparse
from pathlib import Path
from typing import Dict, List, Tuple
import k2
import sentencepiece as spm
import torch
from prepare_lang import (
Lexicon,
add_disambig_symbols,
add_self_loops,
write_lexicon,
write_mapping,
)
from icefall.utils import str2bool
def lexicon_to_fst_no_sil(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format).
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
loop_state = 0 # words enter and leave from here
next_state = 1 # the next un-allocated state, will be incremented as we go
arcs = []
# The blank symbol <blk> is defined in local/train_bpe_model.py
assert token2id["<blk>"] == 0
assert word2id["<eps>"] == 0
eps = 0
for word, pieces in lexicon:
assert len(pieces) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
pieces = [token2id[i] for i in pieces]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, pieces[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last piece of this word
i = len(pieces) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, pieces[i], w, 0])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def generate_lexicon(
model_file: str, words: List[str], oov: str
) -> Tuple[Lexicon, Dict[str, int]]:
"""Generate a lexicon from a BPE model.
Args:
model_file:
Path to a sentencepiece model.
words:
A list of strings representing words.
oov:
The out of vocabulary word in lexicon.
Returns:
Return a tuple with two elements:
- A dict whose keys are words and values are the corresponding
word pieces.
- A dict representing the token symbol, mapping from tokens to IDs.
"""
sp = spm.SentencePieceProcessor()
sp.load(str(model_file))
# Convert word to word piece IDs instead of word piece strings
# to avoid OOV tokens.
words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int)
# Now convert word piece IDs back to word piece strings.
words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids]
lexicon = []
for word, pieces in zip(words, words_pieces):
lexicon.append((word, pieces))
lexicon.append((oov, ["", sp.id_to_piece(sp.unk_id())]))
token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}
return lexicon, token2id
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain the bpe.model and words.txt
""",
)
parser.add_argument(
"--oov",
type=str,
default="<UNK>",
help="The out of vocabulary word in lexicon.",
)
parser.add_argument(
"--debug",
type=str2bool,
default=False,
help="""True for debugging, which will generate
a visualization of the lexicon FST.
Caution: If your lexicon contains hundreds of thousands
of lines, please set it to False!
See "test/test_bpe_lexicon.py" for usage.
""",
)
return parser.parse_args()
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
model_file = lang_dir / "bpe.model"
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
words = word_sym_table.symbols
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", args.oov, "#0", "<s>", "</s>"]
for w in excluded:
if w in words:
words.remove(w)
lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
next_token_id = max(token_sym_table.values()) + 1
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in token_sym_table
token_sym_table[disambig] = next_token_id
next_token_id += 1
word_sym_table.add("#0")
word_sym_table.add("<s>")
word_sym_table.add("</s>")
write_mapping(lang_dir / "tokens.txt", token_sym_table)
write_lexicon(lang_dir / "lexicon.txt", lexicon)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst_no_sil(
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
)
L_disambig = lexicon_to_fst_no_sil(
lexicon_disambig,
token2id=token_sym_table,
word2id=word_sym_table,
need_self_loops=True,
)
torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if args.debug:
labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
L.labels_sym = labels_sym
L.aux_labels_sym = aux_labels_sym
L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
L_disambig.labels_sym = labels_sym
L_disambig.aux_labels_sym = aux_labels_sym
L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# You can install sentencepiece via:
#
# pip install sentencepiece
#
# Due to an issue reported in
# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
#
# Please install a version >=0.1.96
import argparse
import shutil
from pathlib import Path
import sentencepiece as spm
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
The generated bpe.model is saved to this directory.
""",
)
parser.add_argument(
"--transcript",
type=str,
help="Training transcript.",
)
parser.add_argument(
"--vocab-size",
type=int,
help="Vocabulary size for BPE training",
)
return parser.parse_args()
def main():
args = get_args()
vocab_size = args.vocab_size
lang_dir = Path(args.lang_dir)
model_type = "unigram"
model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
train_text = args.transcript
character_coverage = 1.0
input_sentence_size = 100000000
user_defined_symbols = ["<blk>", "<sos/eos>"]
unk_id = len(user_defined_symbols)
# Note: unk_id is fixed to 2.
# If you change it, you should also change other
# places that are using it.
model_file = Path(model_prefix + ".model")
if not model_file.is_file():
spm.SentencePieceTrainer.train(
input=train_text,
vocab_size=vocab_size,
model_type=model_type,
model_prefix=model_prefix,
input_sentence_size=input_sentence_size,
character_coverage=character_coverage,
user_defined_symbols=user_defined_symbols,
unk_id=unk_id,
bos_id=-1,
eos_id=-1,
)
else:
print(f"{model_file} exists - skipping")
return
shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
if __name__ == "__main__":
main()

195
egs/ami/SURT/prepare.sh Executable file
View File

@ -0,0 +1,195 @@
#!/usr/bin/env bash
set -eou pipefail
stage=-1
stop_stage=100
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/ami
# You can find audio and transcripts for AMI in this path.
#
# - $dl_dir/icsi
# You can find audio and transcripts for ICSI in this path.
#
# - $dl_dir/rirs_noises
# This directory contains the RIRS_NOISES corpus downloaded from https://openslr.org/28/.
#
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
vocab_size=500
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/amicorpus,
# you can create a symlink
#
# ln -sfv /path/to/amicorpus $dl_dir/amicorpus
#
if [ ! -d $dl_dir/amicorpus ]; then
for mic in ihm ihm-mix sdm mdm8-bf; do
lhotse download ami --mic $mic $dl_dir/amicorpus
done
fi
# If you have pre-downloaded it to /path/to/icsi,
# you can create a symlink
#
# ln -sfv /path/to/icsi $dl_dir/icsi
#
if [ ! -d $dl_dir/icsi ]; then
lhotse download icsi $dl_dir/icsi
fi
# If you have pre-downloaded it to /path/to/rirs_noises,
# you can create a symlink
#
# ln -sfv /path/to/rirs_noises $dl_dir/
#
if [ ! -d $dl_dir/rirs_noises ]; then
lhotse download rirs_noises $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare AMI manifests"
# We assume that you have downloaded the AMI corpus
# to $dl_dir/amicorpus. We perform text normalization for the transcripts.
mkdir -p data/manifests
for mic in ihm ihm-mix sdm mdm8-bf; do
log "Preparing AMI manifest for $mic"
lhotse prepare ami --mic $mic --max-words-per-segment 30 --merge-consecutive $dl_dir/amicorpus data/manifests/
done
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare ICSI manifests"
# We assume that you have downloaded the ICSI corpus
# to $dl_dir/icsi. We perform text normalization for the transcripts.
mkdir -p data/manifests
log "Preparing ICSI manifest"
for mic in ihm ihm-mix sdm; do
lhotse prepare icsi --mic $mic $dl_dir/icsi data/manifests/
done
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare RIRs"
# We assume that you have downloaded the RIRS_NOISES corpus
# to $dl_dir/rirs_noises
lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises data/manifests
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 3: Extract features for AMI and ICSI recordings"
python local/compute_fbank_ami.py
python local/compute_fbank_icsi.py
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Create sources for simulating mixtures"
# In the following script, we speed-perturb the IHM recordings and extract features.
python local/compute_fbank_ihm.py
lhotse combine data/manifests/ami-ihm_cuts_train.jsonl.gz \
data/manifests/icsi-ihm_cuts_train.jsonl.gz - |\
lhotse cut trim-to-alignments --type word --max-pause 0.5 - - |\
lhotse filter 'duration<=12.0' - - |\
shuf | gzip -c > data/manifests/ihm_cuts_train_trimmed.jsonl.gz
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Create training mixtures"
lhotse workflows simulate-meetings \
--method conversational \
--same-spk-pause 0.5 \
--diff-spk-pause 0.5 \
--diff-spk-overlap 1.0 \
--prob-diff-spk-overlap 0.8 \
--num-meetings 200000 \
--num-speakers-per-meeting 2,3 \
--max-duration-per-speaker 15.0 \
--max-utterances-per-speaker 3 \
--seed 1234 \
--num-jobs 2 \
data/manifests/ihm_cuts_train_trimmed.jsonl.gz \
data/manifests/ai-mix_cuts_clean.jsonl.gz
python local/compute_fbank_aimix.py
# Add source features to the manifest (will be used for masking loss)
# This may take ~2 hours.
python local/add_source_feats.py
# Combine clean and reverb
cat <(gunzip -c data/manifests/cuts_train_clean_sources.jsonl.gz) \
<(gunzip -c data/manifests/cuts_train_reverb_sources.jsonl.gz) |\
shuf | gzip -c > data/manifests/cuts_train_comb_sources.jsonl.gz
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Create training mixtures from real sessions"
python local/prepare_ami_train_cuts.py
python local/prepare_icsi_train_cuts.py
# Combine AMI and ICSI
cat <(gunzip -c data/manifests/cuts_train_ami.jsonl.gz) \
<(gunzip -c data/manifests/cuts_train_icsi.jsonl.gz) |\
shuf | gzip -c > data/manifests/cuts_train_ami_icsi.jsonl.gz
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Dump transcripts for BPE model training (using AMI and ICSI)."
mkdir -p data/lm
cat <(gunzip -c data/manifests/ami-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g') \
<(gunzip -c data/manifests/icsi-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g') \
> data/lm/transcript_words.txt
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Prepare BPE based lang (combining AMI and ICSI)"
lang_dir=data/lang_bpe_${vocab_size}
mkdir -p $lang_dir
# Add special words to words.txt
echo "<eps> 0" > $lang_dir/words.txt
echo "!SIL 1" >> $lang_dir/words.txt
echo "<UNK> 2" >> $lang_dir/words.txt
# Add regular words to words.txt
cat data/lm/transcript_words.txt | grep -o -E '\w+' | sort -u | awk '{print $0,NR+2}' >> $lang_dir/words.txt
# Add remaining special word symbols expected by LM scripts.
num_words=$(cat $lang_dir/words.txt | wc -l)
echo "<s> ${num_words}" >> $lang_dir/words.txt
num_words=$(cat $lang_dir/words.txt | wc -l)
echo "</s> ${num_words}" >> $lang_dir/words.txt
num_words=$(cat $lang_dir/words.txt | wc -l)
echo "#0 ${num_words}" >> $lang_dir/words.txt
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript data/lm/transcript_words.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bpe.py --lang-dir $lang_dir
fi
fi

1
egs/ami/SURT/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared