icefall/egs/librispeech/ASR/local/prepare_rir.py
jaeeunbaik 915e8e399c Add CHiME-4 dataset, RIR and Self-Distillation
- Added CHiME-4 dataset integration in asr_datamodule.py
- Added Hugging Face upload script
- Added RIR augmentation
- Added Self-Distillation Training
2025-08-27 16:11:20 +09:00

139 lines
4.0 KiB
Python

#!/usr/bin/env python3
"""
Prepare RIR (Room Impulse Response) data for lhotse.
This script converts rir.scp file to lhotse manifest format.
"""
import argparse
import logging
from pathlib import Path
from typing import List
from lhotse import CutSet, Recording, SupervisionSegment
from lhotse.audio import AudioSource
from lhotse.utils import Pathlike
def get_args():
parser = argparse.ArgumentParser(
description="Prepare RIR data for lhotse",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--rir-scp",
type=Path,
required=True,
help="Path to rir.scp file containing RIR file paths",
)
parser.add_argument(
"--output-dir",
type=Path,
required=True,
help="Output directory for RIR manifests",
)
return parser.parse_args()
def prepare_rir_manifest(
rir_scp: Pathlike,
output_dir: Pathlike,
) -> None:
"""
Prepare RIR manifest from rir.scp file.
Args:
rir_scp: Path to rir.scp file
output_dir: Output directory for manifests
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
recordings = []
# Read rir.scp file
with open(rir_scp, 'r') as f:
for line_idx, line in enumerate(f):
line = line.strip()
if not line or line.startswith('#'):
continue
# Parse line: either "path" or "id path"
parts = line.split()
if len(parts) == 1:
rir_path = parts[0]
rir_id = f"rir_{line_idx:06d}"
elif len(parts) == 2:
rir_id, rir_path = parts
else:
logging.warning(f"Invalid line in rir.scp: {line}")
continue
# Check if file exists
rir_path = Path(rir_path)
if not rir_path.exists():
logging.warning(f"RIR file not found: {rir_path}")
continue
# Create recording
recording = Recording(
id=rir_id,
sources=[
AudioSource(
type="file",
channels=[0],
source=str(rir_path.resolve()),
)
],
sampling_rate=16000, # Assume 16kHz, will be auto-detected by lhotse
num_samples=None, # Will be auto-detected
duration=None, # Will be auto-detected
)
recordings.append(recording)
logging.info(f"Found {len(recordings)} RIR files")
# Create recording set and save
from lhotse import RecordingSet
recording_set = RecordingSet.from_recordings(recordings)
# Validate recordings (this will auto-detect duration, sampling_rate, etc.)
logging.info("Validating RIR recordings...")
recording_set = recording_set.with_path_prefix("") # Ensure absolute paths
# Save recording manifest
output_path = output_dir / "rir_recordings.jsonl.gz"
recording_set.to_file(output_path)
logging.info(f"Saved RIR recording manifest to {output_path}")
# Create a simple cuts manifest for RIR (whole files)
logging.info("Creating RIR cuts manifest...")
rir_cuts = CutSet.from_manifests(recordings=recording_set)
cuts_output_path = output_dir / "rir_cuts.jsonl.gz"
rir_cuts.to_file(cuts_output_path)
logging.info(f"Saved RIR cuts manifest to {cuts_output_path}")
return recording_set, rir_cuts
def main():
args = get_args()
logging.basicConfig(
format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
level=logging.INFO,
)
logging.info("Preparing RIR data...")
prepare_rir_manifest(
rir_scp=args.rir_scp,
output_dir=args.output_dir,
)
logging.info("Done!")
if __name__ == "__main__":
main()