icefall/egs/librispeech/ASR/local/compute_fbank_rir.py

#!/usr/bin/env python3
"""
This file computes fbank features of the RIR dataset.
It looks for RIR recordings and generates fbank features.
The generated fbank features are saved in data/fbank.
"""

import argparse
import logging
import os
from pathlib import Path
from typing import Optional

import torch
import soundfile as sf
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    LilcomChunkyWriter,
    RecordingSet,
    Recording,
)
from lhotse.audio import AudioSource

from icefall.utils import get_executor

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_rir(
    rir_scp: str = "data/manifests/rir.scp",
    num_mel_bins: int = 80,
    output_dir: str = "data/fbank",
    max_files: Optional[int] = None,
):
    """
    Compute fbank features for RIR files.

    Args:
      rir_scp: Path to the rir.scp file.
      num_mel_bins: Number of mel filter banks.
      output_dir: Output directory for the features.
      max_files: Maximum number of RIR files to process (for testing).
    """
    output_dir = Path(output_dir)
    num_jobs = min(15, os.cpu_count())

    rir_cuts_path = output_dir / "rir_cuts.jsonl.gz"
    if rir_cuts_path.is_file():
        logging.info(f"{rir_cuts_path} already exists - skipping")
        return

    logging.info("Extracting features for RIR")

    # Create RIR recordings from the scp file
    recordings = []
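    # Each line of rir.scp is expected to contain a single path to a RIR wav
    # file (format inferred from the parsing below; lines are stripped and
    # used as-is).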
    with open(rir_scp, "r") as f:
        for idx, line in enumerate(f):
            if max_files and idx >= max_files:
                break

            rir_path = Path(line.strip())
            if not rir_path.exists():
                logging.warning(f"RIR file not found: {rir_path}")
                continue

            rir_id = f"rir_{idx:06d}"
            try:
                # Get audio info using soundfile
                with sf.SoundFile(rir_path) as audio_file:
                    sampling_rate = audio_file.samplerate
                    num_samples = len(audio_file)
                    duration = num_samples / sampling_rate

                # Create recording with proper metadata
                recording = Recording(
                    id=rir_id,
                    sources=[
                        AudioSource(
                            type="file",
                            channels=[0],
                            source=str(rir_path.resolve()),
                        )
                    ],
                    sampling_rate=int(sampling_rate),
                    num_samples=int(num_samples),
                    duration=float(duration),
                )
                recordings.append(recording)
            except Exception as e:
                logging.warning(f"Failed to process {rir_path}: {e}")
                continue

            if (idx + 1) % 1000 == 0:
                logging.info(f"Processed {idx + 1} RIR files...")

    logging.info(f"Found {len(recordings)} RIR files")

    # Create recording set
    rir_recordings = RecordingSet.from_recordings(recordings)

    # Feature extractor
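    # (only num_mel_bins is overridden here; the remaining FbankConfig fields,
    # e.g. frame length/shift, keep lhotse's defaults)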
    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:
        # Create cuts and compute features
        rir_cuts = (
            CutSet.from_manifests(recordings=rir_recordings)
            .compute_and_store_features(
                extractor=extractor,
                storage_path=f"{output_dir}/rir_feats",
                num_jobs=num_jobs if ex is None else 80,
                executor=ex,
                storage_type=LilcomChunkyWriter,
            )
        )
        rir_cuts.to_file(rir_cuts_path)

    logging.info(f"Saved RIR cuts with features to {rir_cuts_path}")
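
# Note (usage sketch, not part of this script's pipeline): the resulting
# manifest can be loaded downstream with
# CutSet.from_file("data/fbank/rir_cuts.jsonl.gz"), e.g. for RIR-based
# reverberation augmentation.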


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--rir-scp",
        type=str,
        default="data/manifests/rir.scp",
        help="Path to rir.scp file. Default: data/manifests/rir.scp",
    )
    parser.add_argument(
        "--num-mel-bins",
        type=int,
        default=80,
        help="The number of mel bins for Fbank. Default: 80",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="data/fbank",
        help="Output directory. Default: data/fbank",
    )
    parser.add_argument(
        "--max-files",
        type=int,
        default=None,
        help="Maximum number of RIR files to process (for testing). Default: None (process all)",
    )
    return parser.parse_args()


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    compute_fbank_rir(
        rir_scp=args.rir_scp,
        num_mel_bins=args.num_mel_bins,
        output_dir=args.output_dir,
        max_files=args.max_files,
    )