#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Data preparation for the AMI GSS-enhanced dataset.

import logging
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

from lhotse import Recording, RecordingSet, SupervisionSet
from lhotse.qa import fix_manifests
from lhotse.recipes.utils import read_manifests_if_cached
from lhotse.utils import fastcopy
from tqdm import tqdm

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)

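# Example invocation (the script path and argument values are illustrative):
#   python local/prepare_ami_gss.py data/manifests exp/ami_gss/enhanced \
#       --num-jobs 4 --min-segment-duration 0.2
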
def get_args():
    import argparse

    parser = argparse.ArgumentParser(description="AMI enhanced dataset preparation.")
    parser.add_argument(
        "manifests_dir",
        type=Path,
        help="Path to directory containing AMI manifests.",
    )
    parser.add_argument(
        "enhanced_dir",
        type=Path,
        help="Path to enhanced data directory.",
    )
    parser.add_argument(
        "--num-jobs",
        "-j",
        type=int,
        default=1,
        help="Number of parallel jobs to run.",
    )
    parser.add_argument(
        "--min-segment-duration",
        "-d",
        type=float,
        default=0.0,
        help="Minimum duration of a segment in seconds.",
    )
    return parser.parse_args()

def find_recording_and_create_new_supervision(enhanced_dir, supervision):
    """
    Given a supervision (corresponding to an original AMI recording), find the
    enhanced recording corresponding to that supervision, and return the recording
    together with a new supervision whose start and end times are adjusted to
    match the enhanced recording. Returns None if the enhanced recording is
    missing or has zero duration.
    """
    file_name = Path(
        f"{supervision.recording_id}-{supervision.speaker}-{int(100*supervision.start):06d}_{int(100*supervision.end):06d}.flac"
    )
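    # For illustration (hypothetical values): a supervision from recording
    # "EN2002a" with speaker "A", spanning 12.34s-15.67s, maps to
    #   <enhanced_dir>/EN2002a/EN2002a-A-001234_001567.flac
    # i.e. start/end times are encoded as zero-padded centiseconds.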
    save_path = enhanced_dir / f"{supervision.recording_id}" / file_name
    if save_path.exists():
        recording = Recording.from_file(save_path)
        if recording.duration == 0:
            logging.warning(f"Skipping {save_path} which has duration 0 seconds.")
            return None

        # The old supervision is relative to the original recording; the new
        # supervision is relative to the enhanced segment.
        new_supervision = fastcopy(
            supervision,
            recording_id=recording.id,
            start=0,
            duration=recording.duration,
        )
        return recording, new_supervision
    else:
        logging.warning(f"{save_path} does not exist.")
        return None

def main(args):
    # Get arguments
    manifests_dir = args.manifests_dir
    enhanced_dir = args.enhanced_dir

    # Load manifests from cache if they exist (saves time)
    manifests = read_manifests_if_cached(
        dataset_parts=["train", "dev", "test"],
        output_dir=manifests_dir,
        prefix="ami-sdm",
        suffix="jsonl.gz",
    )
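    # Note: read_manifests_if_cached expects manifests named like
    # "{prefix}_{manifest}_{part}.{suffix}", e.g. ami-sdm_supervisions_train.jsonl.gz,
    # as produced by the lhotse AMI recipe.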
    if not manifests:
        raise ValueError(f"AMI SDM manifests not found in {manifests_dir}")

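    # A thread pool suffices here: the per-segment work is dominated by
    # filesystem I/O (path checks and reading flac headers), which Python
    # threads parallelize well despite the GIL.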
    with ThreadPoolExecutor(args.num_jobs) as ex:
        for part in ["train", "dev", "test"]:
            logging.info(f"Processing {part}...")
            supervisions_orig = manifests[part]["supervisions"].filter(
                lambda s: s.duration >= args.min_segment_duration
            )
            # Remove TS3009d supervisions since they are not present in the enhanced data
            supervisions_orig = supervisions_orig.filter(
                lambda s: s.recording_id != "TS3009d"
            )
            futures = []
            for supervision in tqdm(
                supervisions_orig,
                desc="Distributing tasks",
            ):
                futures.append(
                    ex.submit(
                        find_recording_and_create_new_supervision,
                        enhanced_dir,
                        supervision,
                    )
                )
            recordings = []
            supervisions = []
            for future in tqdm(
                futures,
                total=len(futures),
                desc="Processing tasks",
            ):
                result = future.result()
                if result is not None:
                    recording, new_supervision = result
                    recordings.append(recording)
                    supervisions.append(new_supervision)
            # Remove duplicates from the recordings
            recordings_nodup = {}
            for recording in recordings:
                if recording.id not in recordings_nodup:
                    recordings_nodup[recording.id] = recording
                else:
                    logging.warning(f"Recording {recording.id} is duplicated.")
            recordings = RecordingSet.from_recordings(recordings_nodup.values())
            supervisions = SupervisionSet.from_segments(supervisions)
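            # fix_manifests (from lhotse.qa) is expected to reconcile the pair:
            # dropping supervisions without a matching recording and trimming
            # segments that extend beyond the recording boundaries.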
            recordings, supervisions = fix_manifests(
                recordings=recordings, supervisions=supervisions
            )
            logging.info(f"Writing {part} enhanced manifests")
            recordings.to_file(manifests_dir / f"ami-gss_recordings_{part}.jsonl.gz")
            supervisions.to_file(
                manifests_dir / f"ami-gss_supervisions_{part}.jsonl.gz"
            )


if __name__ == "__main__":
    args = get_args()
    main(args)