mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
- Added CHiME-4 dataset integration in asr_datamodule.py
- Added Hugging Face upload script
- Added RIR augmentation
- Added Self-Distillation Training
169 lines
4.9 KiB
Python
169 lines
4.9 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
This file computes fbank features of the RIR dataset.
|
|
It looks for RIR recordings and generates fbank features.
|
|
|
|
The generated fbank features are saved in data/fbank.
|
|
"""
|
|
import argparse
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import soundfile as sf
|
|
from lhotse import (
|
|
CutSet,
|
|
Fbank,
|
|
FbankConfig,
|
|
LilcomChunkyWriter,
|
|
MonoCut,
|
|
RecordingSet,
|
|
Recording,
|
|
)
|
|
from lhotse.audio import AudioSource
|
|
|
|
from icefall.utils import get_executor
|
|
|
|
# Torch's multithreaded behavior needs to be disabled or
|
|
# it wastes a lot of CPU and slows things down.
|
|
torch.set_num_threads(1)  # limit intra-op parallelism to a single thread
|
|
torch.set_num_interop_threads(1)  # limit inter-op parallelism to a single thread
|
|
|
|
|
|
def compute_fbank_rir(
|
|
rir_scp: str = "data/manifests/rir.scp",
|
|
num_mel_bins: int = 80,
|
|
output_dir: str = "data/fbank",
|
|
max_files: int = None
|
|
):
|
|
"""
|
|
Compute fbank features for RIR files.
|
|
|
|
Args:
|
|
rir_scp: Path to rir.scp file
|
|
num_mel_bins: Number of mel filter banks
|
|
output_dir: Output directory for features
|
|
max_files: Maximum number of RIR files to process (for testing)
|
|
"""
|
|
output_dir = Path(output_dir)
|
|
num_jobs = min(15, os.cpu_count())
|
|
|
|
rir_cuts_path = output_dir / "rir_cuts.jsonl.gz"
|
|
|
|
if rir_cuts_path.is_file():
|
|
logging.info(f"{rir_cuts_path} already exists - skipping")
|
|
return
|
|
|
|
logging.info("Extracting features for RIR")
|
|
|
|
# Create RIR recordings from scp file
|
|
recordings = []
|
|
with open(rir_scp, 'r') as f:
|
|
for idx, line in enumerate(f):
|
|
if max_files and idx >= max_files:
|
|
break
|
|
|
|
rir_path = Path(line.strip())
|
|
if not rir_path.exists():
|
|
logging.warning(f"RIR file not found: {rir_path}")
|
|
continue
|
|
|
|
rir_id = f"rir_{idx:06d}"
|
|
|
|
try:
|
|
# Get audio info using soundfile
|
|
with sf.SoundFile(rir_path) as audio_file:
|
|
sampling_rate = audio_file.samplerate
|
|
num_samples = len(audio_file)
|
|
duration = num_samples / sampling_rate
|
|
|
|
# Create recording with proper metadata
|
|
recording = Recording(
|
|
id=rir_id,
|
|
sources=[
|
|
AudioSource(
|
|
type="file",
|
|
channels=[0],
|
|
source=str(rir_path.resolve()),
|
|
)
|
|
],
|
|
sampling_rate=int(sampling_rate),
|
|
num_samples=int(num_samples),
|
|
duration=float(duration),
|
|
)
|
|
recordings.append(recording)
|
|
|
|
except Exception as e:
|
|
logging.warning(f"Failed to process {rir_path}: {e}")
|
|
continue
|
|
|
|
if (idx + 1) % 1000 == 0:
|
|
logging.info(f"Processed {idx + 1} RIR files...")
|
|
|
|
logging.info(f"Found {len(recordings)} RIR files")
|
|
|
|
# Create recording set
|
|
rir_recordings = RecordingSet.from_recordings(recordings)
|
|
|
|
# Feature extractor
|
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
|
|
|
with get_executor() as ex:
|
|
# Create cuts and compute features
|
|
rir_cuts = (
|
|
CutSet.from_manifests(recordings=rir_recordings)
|
|
.compute_and_store_features(
|
|
extractor=extractor,
|
|
storage_path=f"{output_dir}/rir_feats",
|
|
num_jobs=num_jobs if ex is None else 80,
|
|
executor=ex,
|
|
storage_type=LilcomChunkyWriter,
|
|
)
|
|
)
|
|
rir_cuts.to_file(rir_cuts_path)
|
|
|
|
logging.info(f"Saved RIR cuts with features to {rir_cuts_path}")
|
|
|
|
|
|
def get_args():
    """Parse the command-line options of this script.

    Returns:
      The ``argparse.Namespace`` with attributes ``rir_scp``,
      ``num_mel_bins``, ``output_dir`` and ``max_files``.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        (
            "--rir-scp",
            dict(
                type=str,
                default="data/manifests/rir.scp",
                help="Path to rir.scp file. Default: data/manifests/rir.scp",
            ),
        ),
        (
            "--num-mel-bins",
            dict(
                type=int,
                default=80,
                help="The number of mel bins for Fbank. Default: 80",
            ),
        ),
        (
            "--output-dir",
            dict(
                type=str,
                default="data/fbank",
                help="Output directory. Default: data/fbank",
            ),
        ),
        (
            "--max-files",
            dict(
                type=int,
                default=None,
                help="Maximum number of RIR files to process (for testing). Default: None (process all)",
            ),
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: configure logging, then run the extraction with
    # the options given on the command line.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    cli_options = get_args()
    compute_fbank_rir(
        rir_scp=cli_options.rir_scp,
        num_mel_bins=cli_options.num_mel_bins,
        output_dir=cli_options.output_dir,
        max_files=cli_options.max_files,
    )
|