mirror of https://github.com/k2-fsa/icefall.git

commit 14818f5dd8 (parent: d6b88aaa98)
initial commit for SURT AMI recipe
egs/ami/SURT/local/add_source_feats.py (Executable file, 78 lines added)
@@ -0,0 +1,78 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file adds source features as temporal arrays to the mixture manifests.
It looks for manifests in the directory data/manifests.
"""
import logging
from pathlib import Path

import numpy as np
from lhotse import CutSet, LilcomChunkyWriter, load_manifest, load_manifest_lazy
from tqdm import tqdm


def add_source_feats():
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")

    logging.info("Reading mixed cuts")
    mixed_cuts_clean = load_manifest_lazy(src_dir / "cuts_train_clean.jsonl.gz")
    mixed_cuts_reverb = load_manifest_lazy(src_dir / "cuts_train_reverb.jsonl.gz")

    logging.info("Reading source cuts")
    source_cuts = load_manifest(src_dir / "ihm_cuts_train_trimmed.jsonl.gz")

    logging.info("Adding source features to the mixed cuts")
    pbar = tqdm(total=len(mixed_cuts_clean), desc="Adding source features")
    with CutSet.open_writer(
        src_dir / "cuts_train_clean_sources.jsonl.gz"
    ) as cut_writer_clean, CutSet.open_writer(
        src_dir / "cuts_train_reverb_sources.jsonl.gz"
    ) as cut_writer_reverb, LilcomChunkyWriter(
        output_dir / "feats_train_clean_sources"
    ) as source_feat_writer:
        for cut_clean, cut_reverb in zip(mixed_cuts_clean, mixed_cuts_reverb):
            assert cut_reverb.id == cut_clean.id + "_rvb"
            source_feats = []
            source_feat_offsets = []
            cur_offset = 0
            for sup in sorted(
                cut_clean.supervisions, key=lambda s: (s.start, s.speaker)
            ):
                source_cut = source_cuts[sup.id]
                source_feats.append(source_cut.load_features())
                source_feat_offsets.append(cur_offset)
                cur_offset += source_cut.num_frames
            cut_clean.source_feats = source_feat_writer.store_array(
                cut_clean.id, np.concatenate(source_feats, axis=0)
            )
            cut_clean.source_feat_offsets = source_feat_offsets
            cut_writer_clean.write(cut_clean)
            # Also write the reverb cut
            cut_reverb.source_feats = cut_clean.source_feats
            cut_reverb.source_feat_offsets = cut_clean.source_feat_offsets
            cut_writer_reverb.write(cut_reverb)
            pbar.update(1)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    add_source_feats()
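For orientation, the packed source features written above can be read back per cut through lhotse's custom-array mechanism. A minimal sketch, assuming the output manifest above and lhotse's auto-generated load_<field>() accessor for custom arrays:

    from lhotse import load_manifest_lazy

    cuts = load_manifest_lazy("data/manifests/cuts_train_clean_sources.jsonl.gz")
    cut = next(iter(cuts))
    feats = cut.load_source_feats()    # concatenated per-speaker feature matrix
    offsets = cut.source_feat_offsets  # frame offset where each source starts
    print(cut.id, feats.shape, offsets)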
egs/ami/SURT/local/compute_fbank_aimix.py (Executable file, 185 lines added)
@@ -0,0 +1,185 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the synthetically mixed AMI and ICSI
train set.
It looks for manifests in the directory data/manifests.

The generated fbank features are saved in data/fbank.
"""
import logging
import random
import warnings
from pathlib import Path

import torch
import torch.multiprocessing
import torchaudio
from lhotse import (
    AudioSource,
    LilcomChunkyWriter,
    Recording,
    load_manifest,
    load_manifest_lazy,
)
from lhotse.audio import set_ffmpeg_torchaudio_info_enabled
from lhotse.cut import MixedCut, MixTrack, MultiCut
from lhotse.features.kaldifeat import (
    KaldifeatFbank,
    KaldifeatFbankConfig,
    KaldifeatFrameOptions,
    KaldifeatMelOptions,
)
from lhotse.utils import fix_random_seed, uuid4
from tqdm import tqdm

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
torchaudio.set_audio_backend("soundfile")
set_ffmpeg_torchaudio_info_enabled(False)


def compute_fbank_aimix():
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")

    sampling_rate = 16000
    num_mel_bins = 80

    extractor = KaldifeatFbank(
        KaldifeatFbankConfig(
            frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
            mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
            device="cuda",
        )
    )

    logging.info("Reading manifests")
    train_cuts = load_manifest_lazy(src_dir / "ai-mix_cuts_clean_full.jsonl.gz")

    # only uses RIRs and noises from REVERB challenge
    real_rirs = load_manifest(src_dir / "real-rir_recordings_all.jsonl.gz").filter(
        lambda r: "RVB2014" in r.id
    )
    noises = load_manifest(src_dir / "iso-noise_recordings_all.jsonl.gz").filter(
        lambda r: "RVB2014" in r.id
    )

    # Apply perturbation to the training cuts
    logging.info("Applying perturbation to the training cuts")
    train_cuts_rvb = train_cuts.map(
        lambda c: augment(
            c, perturb_snr=True, rirs=real_rirs, noises=noises, perturb_loudness=True
        )
    )

    logging.info("Extracting fbank features for training cuts")
    _ = train_cuts.compute_and_store_features_batch(
        extractor=extractor,
        storage_path=output_dir / "ai-mix_feats_clean",
        manifest_path=src_dir / "cuts_train_clean.jsonl.gz",
        batch_duration=5000,
        num_workers=4,
        storage_type=LilcomChunkyWriter,
        overwrite=True,
    )

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        _ = train_cuts_rvb.compute_and_store_features_batch(
            extractor=extractor,
            storage_path=output_dir / "ai-mix_feats_reverb",
            manifest_path=src_dir / "cuts_train_reverb.jsonl.gz",
            batch_duration=5000,
            num_workers=4,
            storage_type=LilcomChunkyWriter,
            overwrite=True,
        )


def augment(cut, perturb_snr=False, rirs=None, noises=None, perturb_loudness=False):
    """
    Given a mixed cut, this function optionally applies the following augmentations:
    - Perturbing the SNRs of the tracks (in range [-5, 5] dB)
    - Reverberation using a randomly selected RIR
    - Adding noise
    - Perturbing the loudness (in range [-20, -25] dB)
    """
    out_cut = cut.drop_features()

    # Perturb the SNRs (optional)
    if perturb_snr:
        snrs = [random.uniform(-5, 5) for _ in range(len(cut.tracks))]
        for i, (track, snr) in enumerate(zip(out_cut.tracks, snrs)):
            if i == 0:
                # Skip the first track since it is the reference
                continue
            track.snr = snr

    # Reverberate the cut (optional)
    if rirs is not None:
        # Select an RIR at random
        rir = random.choice(rirs)
        # Select a channel at random
        rir_channel = random.choice(list(range(rir.num_channels)))
        # Reverberate the cut
        out_cut = out_cut.reverb_rir(rir_recording=rir, rir_channels=[rir_channel])

    # Add noise (optional)
    if noises is not None:
        # Select a noise recording at random
        noise = random.choice(noises).to_cut()
        if isinstance(noise, MultiCut):
            noise = noise.to_mono()[0]
        # Select an SNR at random
        snr = random.uniform(10, 30)
        # Repeat the noise to match the duration of the cut
        noise = repeat_cut(noise, out_cut.duration)
        out_cut = MixedCut(
            id=out_cut.id,
            tracks=[
                MixTrack(cut=out_cut, type="MixedCut"),
                MixTrack(cut=noise, type="DataCut", snr=snr),
            ],
        )

    # Perturb the loudness (optional)
    if perturb_loudness:
        target_loudness = random.uniform(-20, -25)
        out_cut = out_cut.normalize_loudness(target_loudness, mix_first=True)
    return out_cut


def repeat_cut(cut, duration):
    while cut.duration < duration:
        cut = cut.mix(cut, offset_other_by=cut.duration)
    return cut.truncate(duration=duration)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    fix_random_seed(42)
    compute_fbank_aimix()
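The doubling behaviour of repeat_cut() above is easy to trace with plain arithmetic. A small sketch (pure Python, hypothetical numbers, no lhotse involved) showing how a short noise cut reaches the target duration before the final truncate():

    def repeat_trace(start: float, target: float):
        # mirrors the loop in repeat_cut(): mixing the cut with a copy of itself
        # offset by its own duration doubles the length on every iteration
        d = start
        trace = [d]
        while d < target:
            d += d
            trace.append(d)
        return trace

    print(repeat_trace(4.0, 30.0))  # [4.0, 8.0, 16.0, 32.0]; truncate() then trims 32.0 down to 30.0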
egs/ami/SURT/local/compute_fbank_ami.py (Executable file, 94 lines added)
@@ -0,0 +1,94 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the AMI dataset.
We compute features for full recordings (i.e., without trimming to supervisions).
This way we can create arbitrary segmentations later.

The generated fbank features are saved in data/fbank.
"""
import logging
import math
from pathlib import Path

import torch
import torch.multiprocessing
from lhotse import CutSet, LilcomChunkyWriter
from lhotse.features.kaldifeat import (
    KaldifeatFbank,
    KaldifeatFbankConfig,
    KaldifeatFrameOptions,
    KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")


def compute_fbank_ami():
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")

    sampling_rate = 16000
    num_mel_bins = 80

    extractor = KaldifeatFbank(
        KaldifeatFbankConfig(
            frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
            mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
            device="cuda",
        )
    )

    logging.info("Reading manifests")
    manifests = {}
    for part in ["ihm-mix", "sdm", "mdm8-bf"]:
        manifests[part] = read_manifests_if_cached(
            dataset_parts=["train", "dev", "test"],
            output_dir=src_dir,
            prefix=f"ami-{part}",
            suffix="jsonl.gz",
        )

    for part in ["ihm-mix", "sdm", "mdm8-bf"]:
        for split in ["train", "dev", "test"]:
            logging.info(f"Processing {part} {split}")
            cuts = CutSet.from_manifests(
                **manifests[part][split]
            ).compute_and_store_features_batch(
                extractor=extractor,
                storage_path=output_dir / f"ami-{part}_{split}_feats",
                manifest_path=src_dir / f"cuts_ami-{part}_{split}.jsonl.gz",
                batch_duration=5000,
                num_workers=4,
                storage_type=LilcomChunkyWriter,
            )


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    compute_fbank_ami()
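As a quick sanity check on the generated manifests (a sketch with an assumed output path, not part of the recipe scripts), the features stored by compute_and_store_features_batch() can be loaded back per cut:

    from lhotse import load_manifest_lazy

    cuts = load_manifest_lazy("data/manifests/cuts_ami-sdm_train.jsonl.gz")
    cut = next(iter(cuts))
    feats = cut.load_features()  # numpy array, shape (num_frames, 80) with the config above
    print(cut.id, feats.shape)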
egs/ami/SURT/local/compute_fbank_icsi.py (Executable file, 95 lines added)
@@ -0,0 +1,95 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the ICSI dataset.
We compute features for full recordings (i.e., without trimming to supervisions).
This way we can create arbitrary segmentations later.

The generated fbank features are saved in data/fbank.
"""
import logging
import math
from pathlib import Path

import torch
import torch.multiprocessing
from lhotse import CutSet, LilcomChunkyWriter
from lhotse.features.kaldifeat import (
    KaldifeatFbank,
    KaldifeatFbankConfig,
    KaldifeatFrameOptions,
    KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")


def compute_fbank_icsi():
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")

    sampling_rate = 16000
    num_mel_bins = 80

    extractor = KaldifeatFbank(
        KaldifeatFbankConfig(
            frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
            mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
            device="cuda",
        )
    )

    logging.info("Reading manifests")
    manifests = {}
    for part in ["ihm-mix", "sdm"]:
        manifests[part] = read_manifests_if_cached(
            dataset_parts=["train"],
            output_dir=src_dir,
            prefix=f"icsi-{part}",
            suffix="jsonl.gz",
        )

    for part in ["ihm-mix", "sdm"]:
        for split in ["train"]:
            logging.info(f"Processing {part} {split}")
            cuts = CutSet.from_manifests(
                **manifests[part][split]
            ).compute_and_store_features_batch(
                extractor=extractor,
                storage_path=output_dir / f"icsi-{part}_{split}_feats",
                manifest_path=src_dir / f"cuts_icsi-{part}_{split}.jsonl.gz",
                batch_duration=5000,
                num_workers=4,
                storage_type=LilcomChunkyWriter,
                overwrite=True,
            )


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    compute_fbank_icsi()
egs/ami/SURT/local/compute_fbank_ihm.py (Executable file, 101 lines added)
@@ -0,0 +1,101 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the trimmed sub-segments which will be
used for simulating the training mixtures.

The generated fbank features are saved in data/fbank.
"""
import logging
import math
from pathlib import Path

import torch
import torch.multiprocessing
import torchaudio
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
from lhotse.audio import set_ffmpeg_torchaudio_info_enabled
from lhotse.features.kaldifeat import (
    KaldifeatFbank,
    KaldifeatFbankConfig,
    KaldifeatFrameOptions,
    KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached
from tqdm import tqdm

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
torchaudio.set_audio_backend("soundfile")
set_ffmpeg_torchaudio_info_enabled(False)


def compute_fbank_ihm():
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")

    sampling_rate = 16000
    num_mel_bins = 80

    extractor = KaldifeatFbank(
        KaldifeatFbankConfig(
            frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
            mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
            device="cuda",
        )
    )

    logging.info("Reading manifests")
    manifests = {}
    for data in ["ami", "icsi"]:
        manifests[data] = read_manifests_if_cached(
            dataset_parts=["train"],
            output_dir=src_dir,
            types=["recordings", "supervisions"],
            prefix=f"{data}-ihm",
            suffix="jsonl.gz",
        )

    logging.info("Computing features")
    for data in ["ami", "icsi"]:
        cs = CutSet.from_manifests(**manifests[data]["train"])
        cs = cs.trim_to_supervisions(keep_overlapping=False)
        cs = cs.normalize_loudness(target=-23.0, affix_id=False)
        cs = cs + cs.perturb_speed(0.9) + cs.perturb_speed(1.1)
        _ = cs.compute_and_store_features_batch(
            extractor=extractor,
            storage_path=output_dir / f"{data}-ihm_train_feats",
            manifest_path=src_dir / f"{data}-ihm_cuts_train.jsonl.gz",
            batch_duration=5000,
            num_workers=4,
            storage_type=LilcomChunkyWriter,
            overwrite=True,
        )


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    compute_fbank_ihm()
egs/ami/SURT/local/prepare_ami_train_cuts.py (Executable file, 146 lines added)
@@ -0,0 +1,146 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file creates AMI train segments.
"""
import logging
import math
from pathlib import Path

import torch
import torch.multiprocessing
from lhotse import LilcomChunkyWriter, load_manifest_lazy
from lhotse.cut import Cut, CutSet
from lhotse.utils import EPSILON, add_durations
from tqdm import tqdm


def cut_into_windows(cuts: CutSet, duration: float):
    """
    This function takes a CutSet and cuts each cut into windows of roughly
    `duration` seconds. By roughly, we mean that we try to adjust for the last supervision
    that exceeds the duration, or is shorter than the duration.
    """
    res = []
    with tqdm() as pbar:
        for cut in cuts:
            pbar.update(1)
            sups = cut.index_supervisions()[cut.id]
            sr = cut.sampling_rate
            start = 0.0
            end = duration
            num_tries = 0
            while start < cut.duration and num_tries < 2:
                # Find the supervisions that are cut by the window endpoint
                hitlist = [iv for iv in sups.at(end) if iv.begin < end]
                # If there are no supervisions, we are done
                if not hitlist:
                    res.append(
                        cut.truncate(
                            offset=start,
                            duration=add_durations(end, -start, sampling_rate=sr),
                            keep_excessive_supervisions=False,
                        )
                    )
                    # Update the start and end for the next window
                    start = end
                    end = add_durations(end, duration, sampling_rate=sr)
                else:
                    # find ratio of durations cut by the window endpoint
                    ratios = [
                        add_durations(end, -iv.end, sampling_rate=sr) / iv.length()
                        for iv in hitlist
                    ]
                    # we retain the supervisions that have >50% of their duration
                    # in the window, and discard the others
                    retained = []
                    discarded = []
                    for iv, ratio in zip(hitlist, ratios):
                        if ratio > 0.5:
                            retained.append(iv)
                        else:
                            discarded.append(iv)
                    cur_end = max(iv.end for iv in retained) if retained else end
                    res.append(
                        cut.truncate(
                            offset=start,
                            duration=add_durations(cur_end, -start, sampling_rate=sr),
                            keep_excessive_supervisions=False,
                        )
                    )
                    # For the next window, we start at the earliest discarded supervision
                    next_start = min(iv.begin for iv in discarded) if discarded else end
                    next_end = add_durations(next_start, duration, sampling_rate=sr)
                    # It may happen that next_start is the same as start, in which case
                    # we will advance the window anyway
                    if next_start == start:
                        logging.warning(
                            f"Next start is the same as start: {next_start} == {start} for cut {cut.id}"
                        )
                        start = end + EPSILON
                        end = add_durations(start, duration, sampling_rate=sr)
                        num_tries += 1
                    else:
                        start = next_start
                        end = next_end
    return CutSet.from_cuts(res)


def prepare_train_cuts():
    src_dir = Path("data/manifests")

    logging.info("Loading the manifests")
    train_cuts_ihm = load_manifest_lazy(
        src_dir / "cuts_ami-ihm-mix_train.jsonl.gz"
    ).map(lambda c: c.with_id(f"{c.id}_ihm-mix"))
    train_cuts_sdm = load_manifest_lazy(src_dir / "cuts_ami-sdm_train.jsonl.gz").map(
        lambda c: c.with_id(f"{c.id}_sdm")
    )
    train_cuts_mdm = load_manifest_lazy(
        src_dir / "cuts_ami-mdm8-bf_train.jsonl.gz"
    ).map(lambda c: c.with_id(f"{c.id}_mdm8-bf"))

    # Combine all cuts into one CutSet
    train_cuts = train_cuts_ihm + train_cuts_sdm + train_cuts_mdm

    train_cuts_1 = train_cuts.trim_to_supervision_groups(max_pause=0.5)
    train_cuts_2 = train_cuts.trim_to_supervision_groups(max_pause=0.0)

    # Combine the two segmentations
    train_all = train_cuts_1 + train_cuts_2

    # At this point, some of the cuts may be very long. We will cut them into windows of
    # roughly 30 seconds.
    logging.info("Cutting the segments into windows of 30 seconds")
    train_all_30 = cut_into_windows(train_all, duration=30.0)
    logging.info(f"Number of cuts after cutting into windows: {len(train_all_30)}")

    # Show statistics
    train_all.describe(full=True)

    # Save the cuts
    logging.info("Saving the cuts")
    train_all.to_file(src_dir / "cuts_train_ami.jsonl.gz")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    prepare_train_cuts()
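cut_into_windows() above is also reused by the ICSI script that follows. A minimal sketch of calling it directly on an existing cut manifest (path assumed, run from this local/ directory):

    from lhotse import load_manifest_lazy
    from prepare_ami_train_cuts import cut_into_windows

    cuts = load_manifest_lazy("data/manifests/cuts_train_ami.jsonl.gz")
    windowed = cut_into_windows(cuts, duration=30.0)
    print(len(windowed), "windows of roughly 30 seconds")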
egs/ami/SURT/local/prepare_icsi_train_cuts.py (Executable file, 67 lines added)
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file creates ICSI train segments.
"""
import logging
from pathlib import Path

from lhotse import load_manifest_lazy
from prepare_ami_train_cuts import cut_into_windows


def prepare_train_cuts():
    src_dir = Path("data/manifests")

    logging.info("Loading the manifests")
    train_cuts_ihm = load_manifest_lazy(
        src_dir / "cuts_icsi-ihm-mix_train.jsonl.gz"
    ).map(lambda c: c.with_id(f"{c.id}_ihm-mix"))
    train_cuts_sdm = load_manifest_lazy(src_dir / "cuts_icsi-sdm_train.jsonl.gz").map(
        lambda c: c.with_id(f"{c.id}_sdm")
    )

    # Combine all cuts into one CutSet
    train_cuts = train_cuts_ihm + train_cuts_sdm

    train_cuts_1 = train_cuts.trim_to_supervision_groups(max_pause=0.5)
    train_cuts_2 = train_cuts.trim_to_supervision_groups(max_pause=0.0)

    # Combine the two segmentations
    train_all = train_cuts_1 + train_cuts_2

    # At this point, some of the cuts may be very long. We will cut them into windows of
    # roughly 30 seconds.
    logging.info("Cutting the segments into windows of 30 seconds")
    train_all_30 = cut_into_windows(train_all, duration=30.0)
    logging.info(f"Number of cuts after cutting into windows: {len(train_all_30)}")

    # Show statistics
    train_all.describe(full=True)

    # Save the cuts
    logging.info("Saving the cuts")
    train_all.to_file(src_dir / "cuts_train_icsi.jsonl.gz")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    prepare_train_cuts()
egs/ami/SURT/local/prepare_lang.py (Executable file, 413 lines added)
@@ -0,0 +1,413 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
consisting of words and tokens (i.e., phones) and does the following:

1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt

2. Generate tokens.txt, the token table mapping a token to a unique integer.

3. Generate words.txt, the word table mapping a word to a unique integer.

4. Generate L.pt, in k2 format. It can be loaded by

        d = torch.load("L.pt")
        lexicon = k2.Fsa.from_dict(d)

5. Generate L_disambig.pt, in k2 format.
"""
import argparse
import math
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple

import k2
import torch

from icefall.lexicon import read_lexicon, write_lexicon
from icefall.utils import str2bool

Lexicon = List[Tuple[str, List[str]]]


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        It should contain a file lexicon.txt.
        Generated files by this script are saved into this directory.
        """,
    )

    parser.add_argument(
        "--debug",
        type=str2bool,
        default=False,
        help="""True for debugging, which will generate
        a visualization of the lexicon FST.

        Caution: If your lexicon contains hundreds of thousands
        of lines, please set it to False!
        """,
    )

    return parser.parse_args()


def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
    """Write a symbol to ID mapping to a file.

    Note:
      No need to implement `read_mapping` as it can be done
      through :func:`k2.SymbolTable.from_file`.

    Args:
      filename:
        Filename to save the mapping.
      sym2id:
        A dict mapping symbols to IDs.
    Returns:
      Return None.
    """
    with open(filename, "w", encoding="utf-8") as f:
        for sym, i in sym2id.items():
            f.write(f"{sym} {i}\n")


def get_tokens(lexicon: Lexicon) -> List[str]:
    """Get tokens from a lexicon.

    Args:
      lexicon:
        It is the return value of :func:`read_lexicon`.
    Returns:
      Return a list of unique tokens.
    """
    ans = set()
    for _, tokens in lexicon:
        ans.update(tokens)
    sorted_ans = sorted(list(ans))
    return sorted_ans


def get_words(lexicon: Lexicon) -> List[str]:
    """Get words from a lexicon.

    Args:
      lexicon:
        It is the return value of :func:`read_lexicon`.
    Returns:
      Return a list of unique words.
    """
    ans = set()
    for word, _ in lexicon:
        ans.add(word)
    sorted_ans = sorted(list(ans))
    return sorted_ans


def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
    """It adds pseudo-token disambiguation symbols #1, #2 and so on
    at the ends of tokens to ensure that all pronunciations are different,
    and that none is a prefix of another.

    See also add_lex_disambig.pl from kaldi.

    Args:
      lexicon:
        It is returned by :func:`read_lexicon`.
    Returns:
      Return a tuple with two elements:

        - The output lexicon with disambiguation symbols
        - The ID of the max disambiguation symbol that appears
          in the lexicon
    """

    # (1) Work out the count of each token-sequence in the
    # lexicon.
    count = defaultdict(int)
    for _, tokens in lexicon:
        count[" ".join(tokens)] += 1

    # (2) For each left sub-sequence of each token-sequence, note down
    # that it exists (for identifying prefixes of longer strings).
    issubseq = defaultdict(int)
    for _, tokens in lexicon:
        tokens = tokens.copy()
        tokens.pop()
        while tokens:
            issubseq[" ".join(tokens)] = 1
            tokens.pop()

    # (3) For each entry in the lexicon:
    # if the token sequence is unique and is not a
    # prefix of another word, no disambig symbol.
    # Else output #1, or #2, #3, ... if the same token-seq
    # has already been assigned a disambig symbol.
    ans = []

    # We start with #1 since #0 has its own purpose
    first_allowed_disambig = 1
    max_disambig = first_allowed_disambig - 1
    last_used_disambig_symbol_of = defaultdict(int)

    for word, tokens in lexicon:
        tokenseq = " ".join(tokens)
        assert tokenseq != ""
        if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
            ans.append((word, tokens))
            continue

        cur_disambig = last_used_disambig_symbol_of[tokenseq]
        if cur_disambig == 0:
            cur_disambig = first_allowed_disambig
        else:
            cur_disambig += 1

        if cur_disambig > max_disambig:
            max_disambig = cur_disambig
        last_used_disambig_symbol_of[tokenseq] = cur_disambig
        tokenseq += f" #{cur_disambig}"
        ans.append((word, tokenseq.split()))
    return ans, max_disambig


def generate_id_map(symbols: List[str]) -> Dict[str, int]:
    """Generate ID maps, i.e., map a symbol to a unique ID.

    Args:
      symbols:
        A list of unique symbols.
    Returns:
      A dict containing the mapping between symbols and IDs.
    """
    return {sym: i for i, sym in enumerate(symbols)}


def add_self_loops(
    arcs: List[List[Any]], disambig_token: int, disambig_word: int
) -> List[List[Any]]:
    """Adds self-loops to states of an FST to propagate disambiguation symbols
    through it. They are added on each state with non-epsilon output symbols
    on at least one arc out of the state.

    See also fstaddselfloops.pl from Kaldi. One difference is that
    Kaldi uses OpenFst style FSTs and it has multiple final states.
    This function uses k2 style FSTs and it does not need to add self-loops
    to the final state.

    The input label of a self-loop is `disambig_token`, while the output
    label is `disambig_word`.

    Args:
      arcs:
        A list-of-list. The sublist contains
        `[src_state, dest_state, label, aux_label, score]`
      disambig_token:
        It is the token ID of the symbol `#0`.
      disambig_word:
        It is the word ID of the symbol `#0`.

    Return:
      Return new `arcs` containing self-loops.
    """
    states_needs_self_loops = set()
    for arc in arcs:
        src, dst, ilabel, olabel, score = arc
        if olabel != 0:
            states_needs_self_loops.add(src)

    ans = []
    for s in states_needs_self_loops:
        ans.append([s, s, disambig_token, disambig_word, 0])

    return arcs + ans


def lexicon_to_fst(
    lexicon: Lexicon,
    token2id: Dict[str, int],
    word2id: Dict[str, int],
    sil_token: str = "SIL",
    sil_prob: float = 0.5,
    need_self_loops: bool = False,
) -> k2.Fsa:
    """Convert a lexicon to an FST (in k2 format) with optional silence at
    the beginning and end of each word.

    Args:
      lexicon:
        The input lexicon. See also :func:`read_lexicon`
      token2id:
        A dict mapping tokens to IDs.
      word2id:
        A dict mapping words to IDs.
      sil_token:
        The silence token.
      sil_prob:
        The probability for adding a silence at the beginning and end
        of the word.
      need_self_loops:
        If True, add self-loop to states with non-epsilon output symbols
        on at least one arc out of the state. The input label for this
        self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
    Returns:
      Return an instance of `k2.Fsa` representing the given lexicon.
    """
    assert sil_prob > 0.0 and sil_prob < 1.0
    # CAUTION: we use score, i.e, negative cost.
    sil_score = math.log(sil_prob)
    no_sil_score = math.log(1.0 - sil_prob)

    start_state = 0
    loop_state = 1  # words enter and leave from here
    sil_state = 2  # words terminate here when followed by silence; this state
    # has a silence transition to loop_state.
    next_state = 3  # the next un-allocated state, will be incremented as we go.
    arcs = []

    assert token2id["<eps>"] == 0
    assert word2id["<eps>"] == 0

    eps = 0

    sil_token = token2id[sil_token]

    arcs.append([start_state, loop_state, eps, eps, no_sil_score])
    arcs.append([start_state, sil_state, eps, eps, sil_score])
    arcs.append([sil_state, loop_state, sil_token, eps, 0])

    for word, tokens in lexicon:
        assert len(tokens) > 0, f"{word} has no pronunciations"
        cur_state = loop_state

        word = word2id[word]
        tokens = [token2id[i] for i in tokens]

        for i in range(len(tokens) - 1):
            w = word if i == 0 else eps
            arcs.append([cur_state, next_state, tokens[i], w, 0])

            cur_state = next_state
            next_state += 1

        # now for the last token of this word
        # It has two out-going arcs, one to the loop state,
        # the other one to the sil_state.
        i = len(tokens) - 1
        w = word if i == 0 else eps
        arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
        arcs.append([cur_state, sil_state, tokens[i], w, sil_score])

    if need_self_loops:
        disambig_token = token2id["#0"]
        disambig_word = word2id["#0"]
        arcs = add_self_loops(
            arcs,
            disambig_token=disambig_token,
            disambig_word=disambig_word,
        )

    final_state = next_state
    arcs.append([loop_state, final_state, -1, -1, 0])
    arcs.append([final_state])

    arcs = sorted(arcs, key=lambda arc: arc[0])
    arcs = [[str(i) for i in arc] for arc in arcs]
    arcs = [" ".join(arc) for arc in arcs]
    arcs = "\n".join(arcs)

    fsa = k2.Fsa.from_str(arcs, acceptor=False)
    return fsa


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)
    lexicon_filename = lang_dir / "lexicon.txt"
    sil_token = "SIL"
    sil_prob = 0.5

    lexicon = read_lexicon(lexicon_filename)
    tokens = get_tokens(lexicon)
    words = get_words(lexicon)

    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)

    for i in range(max_disambig + 1):
        disambig = f"#{i}"
        assert disambig not in tokens
        tokens.append(f"#{i}")

    assert "<eps>" not in tokens
    tokens = ["<eps>"] + tokens

    assert "<eps>" not in words
    assert "#0" not in words
    assert "<s>" not in words
    assert "</s>" not in words

    words = ["<eps>"] + words + ["#0", "<s>", "</s>"]

    token2id = generate_id_map(tokens)
    word2id = generate_id_map(words)

    write_mapping(lang_dir / "tokens.txt", token2id)
    write_mapping(lang_dir / "words.txt", word2id)
    write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)

    L = lexicon_to_fst(
        lexicon,
        token2id=token2id,
        word2id=word2id,
        sil_token=sil_token,
        sil_prob=sil_prob,
    )

    L_disambig = lexicon_to_fst(
        lexicon_disambig,
        token2id=token2id,
        word2id=word2id,
        sil_token=sil_token,
        sil_prob=sil_prob,
        need_self_loops=True,
    )
    torch.save(L.as_dict(), lang_dir / "L.pt")
    torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")

    if args.debug:
        labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
        aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")

        L.labels_sym = labels_sym
        L.aux_labels_sym = aux_labels_sym
        L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")

        L_disambig.labels_sym = labels_sym
        L_disambig.aux_labels_sym = aux_labels_sym
        L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")


if __name__ == "__main__":
    main()
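To make the disambiguation step concrete, here is a toy run of add_disambig_symbols() with hypothetical words and phones (executed from this local/ directory):

    from prepare_lang import add_disambig_symbols

    lexicon = [
        ("HELLO", ["HH", "EH", "L", "OW"]),
        ("HULLO", ["HH", "EH", "L", "OW"]),  # identical pronunciation, needs #1/#2
        ("HELL", ["HH", "EH", "L"]),         # prefix of the sequence above, needs #1
    ]
    ans, max_disambig = add_disambig_symbols(lexicon)
    # ans[0] == ("HELLO", ["HH", "EH", "L", "OW", "#1"])
    # ans[1] == ("HULLO", ["HH", "EH", "L", "OW", "#2"])
    # ans[2] == ("HELL", ["HH", "EH", "L", "#1"])
    # max_disambig == 2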
266
egs/ami/SURT/local/prepare_lang_bpe.py
Executable file
266
egs/ami/SURT/local/prepare_lang_bpe.py
Executable file
@ -0,0 +1,266 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
This script takes as input `lang_dir`, which should contain::
|
||||||
|
|
||||||
|
- lang_dir/bpe.model,
|
||||||
|
- lang_dir/words.txt
|
||||||
|
|
||||||
|
and generates the following files in the directory `lang_dir`:
|
||||||
|
|
||||||
|
- lexicon.txt
|
||||||
|
- lexicon_disambig.txt
|
||||||
|
- L.pt
|
||||||
|
- L_disambig.pt
|
||||||
|
- tokens.txt
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
from prepare_lang import (
|
||||||
|
Lexicon,
|
||||||
|
add_disambig_symbols,
|
||||||
|
add_self_loops,
|
||||||
|
write_lexicon,
|
||||||
|
write_mapping,
|
||||||
|
)
|
||||||
|
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def lexicon_to_fst_no_sil(
|
||||||
|
lexicon: Lexicon,
|
||||||
|
token2id: Dict[str, int],
|
||||||
|
word2id: Dict[str, int],
|
||||||
|
need_self_loops: bool = False,
|
||||||
|
) -> k2.Fsa:
|
||||||
|
"""Convert a lexicon to an FST (in k2 format).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lexicon:
|
||||||
|
The input lexicon. See also :func:`read_lexicon`
|
||||||
|
token2id:
|
||||||
|
A dict mapping tokens to IDs.
|
||||||
|
word2id:
|
||||||
|
A dict mapping words to IDs.
|
||||||
|
need_self_loops:
|
||||||
|
If True, add self-loop to states with non-epsilon output symbols
|
||||||
|
on at least one arc out of the state. The input label for this
|
||||||
|
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
|
||||||
|
Returns:
|
||||||
|
Return an instance of `k2.Fsa` representing the given lexicon.
|
||||||
|
"""
|
||||||
|
loop_state = 0 # words enter and leave from here
|
||||||
|
next_state = 1 # the next un-allocated state, will be incremented as we go
|
||||||
|
|
||||||
|
arcs = []
|
||||||
|
|
||||||
|
# The blank symbol <blk> is defined in local/train_bpe_model.py
|
||||||
|
assert token2id["<blk>"] == 0
|
||||||
|
assert word2id["<eps>"] == 0
|
||||||
|
|
||||||
|
eps = 0
|
||||||
|
|
||||||
|
for word, pieces in lexicon:
|
||||||
|
assert len(pieces) > 0, f"{word} has no pronunciations"
|
||||||
|
cur_state = loop_state
|
||||||
|
|
||||||
|
word = word2id[word]
|
||||||
|
pieces = [token2id[i] for i in pieces]
|
||||||
|
|
||||||
|
for i in range(len(pieces) - 1):
|
||||||
|
w = word if i == 0 else eps
|
||||||
|
arcs.append([cur_state, next_state, pieces[i], w, 0])
|
||||||
|
|
||||||
|
cur_state = next_state
|
||||||
|
next_state += 1
|
||||||
|
|
||||||
|
# now for the last piece of this word
|
||||||
|
i = len(pieces) - 1
|
||||||
|
w = word if i == 0 else eps
|
||||||
|
arcs.append([cur_state, loop_state, pieces[i], w, 0])
|
||||||
|
|
||||||
|
if need_self_loops:
|
||||||
|
disambig_token = token2id["#0"]
|
||||||
|
disambig_word = word2id["#0"]
|
||||||
|
arcs = add_self_loops(
|
||||||
|
arcs,
|
||||||
|
disambig_token=disambig_token,
|
||||||
|
disambig_word=disambig_word,
|
||||||
|
)
|
||||||
|
|
||||||
|
final_state = next_state
|
||||||
|
arcs.append([loop_state, final_state, -1, -1, 0])
|
||||||
|
arcs.append([final_state])
|
||||||
|
|
||||||
|
arcs = sorted(arcs, key=lambda arc: arc[0])
|
||||||
|
arcs = [[str(i) for i in arc] for arc in arcs]
|
||||||
|
arcs = [" ".join(arc) for arc in arcs]
|
||||||
|
arcs = "\n".join(arcs)
|
||||||
|
|
||||||
|
fsa = k2.Fsa.from_str(arcs, acceptor=False)
|
||||||
|
return fsa
|
||||||
|
|
||||||
|
|
||||||
|
def generate_lexicon(
|
||||||
|
model_file: str, words: List[str], oov: str
|
||||||
|
) -> Tuple[Lexicon, Dict[str, int]]:
|
||||||
|
"""Generate a lexicon from a BPE model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_file:
|
||||||
|
Path to a sentencepiece model.
|
||||||
|
words:
|
||||||
|
A list of strings representing words.
|
||||||
|
oov:
|
||||||
|
The out of vocabulary word in lexicon.
|
||||||
|
    Returns:
      Return a tuple with two elements:
        - A dict whose keys are words and values are the corresponding
          word pieces.
        - A dict representing the token symbol table, mapping from tokens to IDs.
    """
    sp = spm.SentencePieceProcessor()
    sp.load(str(model_file))

    # Convert word to word piece IDs instead of word piece strings
    # to avoid OOV tokens.
    words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int)

    # Now convert word piece IDs back to word piece strings.
    words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids]

    lexicon = []
    for word, pieces in zip(words, words_pieces):
        lexicon.append((word, pieces))

    lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())]))

    token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}

    return lexicon, token2id


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        It should contain the bpe.model and words.txt
        """,
    )

    parser.add_argument(
        "--oov",
        type=str,
        default="<UNK>",
        help="The out of vocabulary word in lexicon.",
    )

    parser.add_argument(
        "--debug",
        type=str2bool,
        default=False,
        help="""True for debugging, which will generate
        a visualization of the lexicon FST.

        Caution: If your lexicon contains hundreds of thousands
        of lines, please set it to False!

        See "test/test_bpe_lexicon.py" for usage.
        """,
    )

    return parser.parse_args()


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)
    model_file = lang_dir / "bpe.model"

    word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")

    words = word_sym_table.symbols

    excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", args.oov, "#0", "<s>", "</s>"]

    for w in excluded:
        if w in words:
            words.remove(w)

    lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov)

    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)

    next_token_id = max(token_sym_table.values()) + 1
    for i in range(max_disambig + 1):
        disambig = f"#{i}"
        assert disambig not in token_sym_table
        token_sym_table[disambig] = next_token_id
        next_token_id += 1

    word_sym_table.add("#0")
    word_sym_table.add("<s>")
    word_sym_table.add("</s>")

    write_mapping(lang_dir / "tokens.txt", token_sym_table)

    write_lexicon(lang_dir / "lexicon.txt", lexicon)
    write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)

    L = lexicon_to_fst_no_sil(
        lexicon,
        token2id=token_sym_table,
        word2id=word_sym_table,
    )

    L_disambig = lexicon_to_fst_no_sil(
        lexicon_disambig,
        token2id=token_sym_table,
        word2id=word_sym_table,
        need_self_loops=True,
    )
    torch.save(L.as_dict(), lang_dir / "L.pt")
    torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")

    if args.debug:
        labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
        aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")

        L.labels_sym = labels_sym
        L.aux_labels_sym = aux_labels_sym
        L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")

        L_disambig.labels_sym = labels_sym
        L_disambig.aux_labels_sym = aux_labels_sym
        L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")


if __name__ == "__main__":
    main()
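
A minimal, illustrative sketch (not part of the recipe files) of the lexicon entries that generate_lexicon() produces. It assumes a BPE model has already been trained and copied to data/lang_bpe_500/bpe.model; the example words are hypothetical and the exact pieces depend on the trained model.

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")

words = ["HELLO", "OKAY"]
# Encode to IDs first (as generate_lexicon does) to avoid OOV pieces,
# then map the IDs back to piece strings.
pieces = [sp.id_to_piece(ids) for ids in sp.encode(words, out_type=int)]
print(list(zip(words, pieces)))
# e.g. [('HELLO', ['▁HE', 'LL', 'O']), ('OKAY', ['▁OKAY'])]  -- model-dependent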
100
egs/ami/SURT/local/train_bpe_model.py
Executable file
@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# You can install sentencepiece via:
#
#   pip install sentencepiece
#
# Due to an issue reported in
# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
#
# Please install a version >=0.1.96

import argparse
import shutil
from pathlib import Path

import sentencepiece as spm


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        The generated bpe.model is saved to this directory.
        """,
    )

    parser.add_argument(
        "--transcript",
        type=str,
        help="Training transcript.",
    )

    parser.add_argument(
        "--vocab-size",
        type=int,
        help="Vocabulary size for BPE training",
    )

    return parser.parse_args()


def main():
    args = get_args()
    vocab_size = args.vocab_size
    lang_dir = Path(args.lang_dir)

    model_type = "unigram"

    model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
    train_text = args.transcript
    character_coverage = 1.0
    input_sentence_size = 100000000

    user_defined_symbols = ["<blk>", "<sos/eos>"]
    unk_id = len(user_defined_symbols)
    # Note: unk_id is fixed to 2.
    # If you change it, you should also change other
    # places that are using it.

    model_file = Path(model_prefix + ".model")
    if not model_file.is_file():
        spm.SentencePieceTrainer.train(
            input=train_text,
            vocab_size=vocab_size,
            model_type=model_type,
            model_prefix=model_prefix,
            input_sentence_size=input_sentence_size,
            character_coverage=character_coverage,
            user_defined_symbols=user_defined_symbols,
            unk_id=unk_id,
            bos_id=-1,
            eos_id=-1,
        )
    else:
        print(f"{model_file} exists - skipping")
        return

    shutil.copyfile(model_file, f"{lang_dir}/bpe.model")


if __name__ == "__main__":
    main()
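
As a quick, illustrative sanity check (not part of the recipe files), the settings above should yield the token layout the rest of the recipe assumes: <blk>=0, <sos/eos>=1, <unk>=2. The path below assumes the trained model was copied to data/lang_bpe_500/bpe.model.

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")
print([sp.id_to_piece(i) for i in range(3)])  # expected: ['<blk>', '<sos/eos>', '<unk>']
print(sp.encode("OKAY SO LET US GET STARTED", out_type=str))  # pieces for a sample transcript line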
195
egs/ami/SURT/prepare.sh
Executable file
@ -0,0 +1,195 @@
#!/usr/bin/env bash

set -eou pipefail

stage=-1
stop_stage=100

# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
#  - $dl_dir/amicorpus
#      You can find audio and transcripts for AMI in this path.
#
#  - $dl_dir/icsi
#      You can find audio and transcripts for ICSI in this path.
#
#  - $dl_dir/rirs_noises
#      This directory contains the RIRS_NOISES corpus downloaded from https://openslr.org/28/.
#
dl_dir=$PWD/download

. shared/parse_options.sh || exit 1

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
vocab_size=500

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "dl_dir: $dl_dir"

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Download data"

  # If you have pre-downloaded it to /path/to/amicorpus,
  # you can create a symlink
  #
  #   ln -sfv /path/to/amicorpus $dl_dir/amicorpus
  #
  if [ ! -d $dl_dir/amicorpus ]; then
    for mic in ihm ihm-mix sdm mdm8-bf; do
      lhotse download ami --mic $mic $dl_dir/amicorpus
    done
  fi

  # If you have pre-downloaded it to /path/to/icsi,
  # you can create a symlink
  #
  #   ln -sfv /path/to/icsi $dl_dir/icsi
  #
  if [ ! -d $dl_dir/icsi ]; then
    lhotse download icsi $dl_dir/icsi
  fi

  # If you have pre-downloaded it to /path/to/rirs_noises,
  # you can create a symlink
  #
  #   ln -sfv /path/to/rirs_noises $dl_dir/
  #
  if [ ! -d $dl_dir/rirs_noises ]; then
    lhotse download rirs_noises $dl_dir
  fi
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Prepare AMI manifests"
  # We assume that you have downloaded the AMI corpus
  # to $dl_dir/amicorpus. We perform text normalization for the transcripts.
  mkdir -p data/manifests
  for mic in ihm ihm-mix sdm mdm8-bf; do
    log "Preparing AMI manifest for $mic"
    lhotse prepare ami --mic $mic --max-words-per-segment 30 --merge-consecutive $dl_dir/amicorpus data/manifests/
  done
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Prepare ICSI manifests"
  # We assume that you have downloaded the ICSI corpus
  # to $dl_dir/icsi. We perform text normalization for the transcripts.
  mkdir -p data/manifests
  log "Preparing ICSI manifest"
  for mic in ihm ihm-mix sdm; do
    lhotse prepare icsi --mic $mic $dl_dir/icsi data/manifests/
  done
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Prepare RIRs"
  # We assume that you have downloaded the RIRS_NOISES corpus
  # to $dl_dir/rirs_noises
  lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises data/manifests
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Extract features for AMI and ICSI recordings"
  python local/compute_fbank_ami.py
  python local/compute_fbank_icsi.py
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Create sources for simulating mixtures"
  # In the following script, we speed-perturb the IHM recordings and extract features.
  python local/compute_fbank_ihm.py
  lhotse combine data/manifests/ami-ihm_cuts_train.jsonl.gz \
    data/manifests/icsi-ihm_cuts_train.jsonl.gz - |\
    lhotse cut trim-to-alignments --type word --max-pause 0.5 - - |\
    lhotse filter 'duration<=12.0' - - |\
    shuf | gzip -c > data/manifests/ihm_cuts_train_trimmed.jsonl.gz
fi
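# (Optional sanity check, not part of the original script: the trimmed source
# cuts can be inspected with lhotse's CutSet.describe(), e.g.
#   python -c "from lhotse import load_manifest_lazy; \
#     load_manifest_lazy('data/manifests/ihm_cuts_train_trimmed.jsonl.gz').describe()"
# to confirm that all cuts are at most 12 s long.)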

if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  log "Stage 6: Create training mixtures"
  lhotse workflows simulate-meetings \
    --method conversational \
    --same-spk-pause 0.5 \
    --diff-spk-pause 0.5 \
    --diff-spk-overlap 1.0 \
    --prob-diff-spk-overlap 0.8 \
    --num-meetings 200000 \
    --num-speakers-per-meeting 2,3 \
    --max-duration-per-speaker 15.0 \
    --max-utterances-per-speaker 3 \
    --seed 1234 \
    --num-jobs 2 \
    data/manifests/ihm_cuts_train_trimmed.jsonl.gz \
    data/manifests/ai-mix_cuts_clean.jsonl.gz

  python local/compute_fbank_aimix.py

  # Add source features to the manifest (will be used for masking loss)
  # This may take ~2 hours.
  python local/add_source_feats.py

  # Combine clean and reverb
  cat <(gunzip -c data/manifests/cuts_train_clean_sources.jsonl.gz) \
    <(gunzip -c data/manifests/cuts_train_reverb_sources.jsonl.gz) |\
    shuf | gzip -c > data/manifests/cuts_train_comb_sources.jsonl.gz
fi

if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
  log "Stage 7: Create training mixtures from real sessions"
  python local/prepare_ami_train_cuts.py
  python local/prepare_icsi_train_cuts.py

  # Combine AMI and ICSI
  cat <(gunzip -c data/manifests/cuts_train_ami.jsonl.gz) \
    <(gunzip -c data/manifests/cuts_train_icsi.jsonl.gz) |\
    shuf | gzip -c > data/manifests/cuts_train_ami_icsi.jsonl.gz
fi

if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
  log "Stage 8: Dump transcripts for BPE model training (using AMI and ICSI)."
  mkdir -p data/lm
  cat <(gunzip -c data/manifests/ami-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g') \
    <(gunzip -c data/manifests/icsi-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g') \
    > data/lm/transcript_words.txt
fi

if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
  log "Stage 9: Prepare BPE based lang (combining AMI and ICSI)"

  lang_dir=data/lang_bpe_${vocab_size}
  mkdir -p $lang_dir

  # Add special words to words.txt
  echo "<eps> 0" > $lang_dir/words.txt
  echo "!SIL 1" >> $lang_dir/words.txt
  echo "<UNK> 2" >> $lang_dir/words.txt

  # Add regular words to words.txt
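  # (NR+2 below makes regular word IDs start at 3, since IDs 0-2 are already
  # taken by <eps>, !SIL and <UNK> above.)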
  cat data/lm/transcript_words.txt | grep -o -E '\w+' | sort -u | awk '{print $0,NR+2}' >> $lang_dir/words.txt

  # Add remaining special word symbols expected by LM scripts.
  num_words=$(cat $lang_dir/words.txt | wc -l)
  echo "<s> ${num_words}" >> $lang_dir/words.txt
  num_words=$(cat $lang_dir/words.txt | wc -l)
  echo "</s> ${num_words}" >> $lang_dir/words.txt
  num_words=$(cat $lang_dir/words.txt | wc -l)
  echo "#0 ${num_words}" >> $lang_dir/words.txt

  ./local/train_bpe_model.py \
    --lang-dir $lang_dir \
    --vocab-size $vocab_size \
    --transcript data/lm/transcript_words.txt

  if [ ! -f $lang_dir/L_disambig.pt ]; then
    ./local/prepare_lang_bpe.py --lang-dir $lang_dir
  fi
fi
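
# Usage sketch (illustrative, not part of the original script): the --stage and
# --stop-stage options are parsed by shared/parse_options.sh, so individual
# steps can be re-run in isolation, e.g.
#
#   ./prepare.sh --stage 8 --stop-stage 9   # only transcript dump + BPE lang prep
#   ./prepare.sh --stage 0                  # run everything from data download onwards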
1
egs/ami/SURT/shared
Symbolic link
@ -0,0 +1 @@
../../../icefall/shared