#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
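# NOTE: The usage examples in the docstring below are illustrative sketches.
# The script name, dataset/subset names, split range, and directory layout are
# placeholders; only the command-line flags come from get_args() in this file.
"""
Compute log-mel filter-bank (fbank) features with TorchAudio and store them
with lhotse (compressed via LilcomChunkyWriter).

Example (whole subset in one pass):

    ./compute_fbank.py \
        --dataset mydataset \
        --subset train \
        --source-dir data/manifests \
        --dest-dir data/fbank \
        --num-jobs 20

Example (pre-split cut manifests, named like
mydataset_cuts_train.00000000.jsonl.gz):

    ./compute_fbank.py \
        --dataset mydataset \
        --subset train \
        --split-cuts 1 \
        --split-begin 0 \
        --split-end 100
"""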
import argparse
import logging
import os
from concurrent.futures import ProcessPoolExecutor as Pool
from pathlib import Path
from typing import Optional

import lhotse
import torch
from feature import TorchAudioFbank, TorchAudioFbankConfig
from lhotse import (
    CutSet,
    LilcomChunkyWriter,
    load_manifest_lazy,
    set_audio_duration_mismatch_tolerance,
)

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def str2bool(v):
    """Used in argparse.ArgumentParser.add_argument to indicate
    that a type is a bool type and the user can enter

        - yes, true, t, y, 1, to represent True
        - no, false, f, n, 0, to represent False

    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse  # noqa
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--sampling-rate",
        type=int,
        default=24000,
        help="The target sampling rate; the audio will be resampled to this rate.",
    )

    parser.add_argument(
        "--frame-shift",
        type=int,
        default=256,
        help="Frame shift in samples.",
    )

    parser.add_argument(
        "--frame-length",
        type=int,
        default=1024,
        help="Frame length in samples.",
    )

    parser.add_argument(
        "--num-mel-bins",
        type=int,
        default=100,
        help="The number of mel filters.",
    )

    parser.add_argument(
        "--dataset",
        type=str,
        help="Dataset name.",
    )

    parser.add_argument(
        "--subset",
        type=str,
        help="The subset of the dataset.",
    )

    parser.add_argument(
        "--source-dir",
        type=str,
        default="data/manifests",
        help="The source directory of manifest files.",
    )

    parser.add_argument(
        "--dest-dir",
        type=str,
        default="data/fbank",
        help="The destination directory of manifest files.",
    )

    parser.add_argument(
        "--split-cuts",
        type=str2bool,
        default=False,
        help="Whether to use split cuts.",
    )

    parser.add_argument(
        "--split-begin",
        type=int,
        help="Start index of the split cuts.",
    )

    parser.add_argument(
        "--split-end",
        type=int,
        help="End index of the split cuts.",
    )

    parser.add_argument(
        "--batch-duration",
        type=int,
        default=1000,
        help="The batch duration when computing the features.",
    )

    parser.add_argument(
        "--num-jobs",
        type=int,
        default=20,
        help="The number of extractor workers.",
    )

    return parser.parse_args()


def compute_fbank_split_single(params, idx):
    """Compute fbank features for a single split (identified by ``idx``)."""
    # Audio and supervision durations in Emilia can mismatch slightly,
    # so relax lhotse's tolerance.
    lhotse.set_audio_duration_mismatch_tolerance(0.1)  # for emilia
    src_dir = Path(params.source_dir)
    output_dir = Path(params.dest_dir)
    num_mel_bins = params.num_mel_bins

    if not src_dir.exists():
        logging.error(f"{src_dir} does not exist")
        return

    if not output_dir.exists():
        output_dir.mkdir(parents=True, exist_ok=True)

    num_digits = 8  # number of digits in the split index, e.g. 00000000

    config = TorchAudioFbankConfig(
        sampling_rate=params.sampling_rate,
        n_mels=params.num_mel_bins,
        n_fft=params.frame_length,
        hop_length=params.frame_shift,
    )
    extractor = TorchAudioFbank(config)

    prefix = params.dataset
    subset = params.subset
    suffix = "jsonl.gz"

    idx = f"{idx}".zfill(num_digits)
    cuts_filename = f"{prefix}_cuts_{subset}.{idx}.{suffix}"
    if (src_dir / cuts_filename).is_file():
        logging.info(f"Loading manifest {src_dir / cuts_filename}")
        cut_set = load_manifest_lazy(src_dir / cuts_filename)
    else:
        logging.warning(f"Raw {cuts_filename} does not exist, skipping")
        return

    cut_set = cut_set.resample(params.sampling_rate)

    if (output_dir / cuts_filename).is_file():
        logging.info(f"{cuts_filename} already exists - skipping.")
        return

    logging.info(f"Processing {subset}.{idx} of {prefix}")

    cut_set = cut_set.compute_and_store_features_batch(
        extractor=extractor,
        storage_path=f"{output_dir}/{prefix}_feats_{subset}_{idx}",
        num_workers=4,
        batch_duration=params.batch_duration,
        storage_type=LilcomChunkyWriter,
        overwrite=True,
    )
    cut_set.to_file(output_dir / cuts_filename)


def compute_fbank_split(params):
    """Compute fbank features for splits [split_begin, split_end) in parallel."""
    if params.split_end < params.split_begin:
        logging.warning(
            f"Split begin should be smaller than split end, given "
            f"{params.split_begin} -> {params.split_end}."
        )
        return

    with Pool(max_workers=params.num_jobs) as pool:
        futures = [
            pool.submit(compute_fbank_split_single, params, i)
            for i in range(params.split_begin, params.split_end)
        ]
        for f in futures:
            # result() blocks until the split finishes and re-raises any
            # exception from the worker process.
            f.result()


def compute_fbank(params):
    """Compute fbank features for the whole subset in a single pass."""
    src_dir = Path(params.source_dir)
    output_dir = Path(params.dest_dir)
    num_jobs = params.num_jobs
    num_mel_bins = params.num_mel_bins

    prefix = params.dataset
    subset = params.subset
    suffix = "jsonl.gz"

    cut_set_name = f"{prefix}_cuts_{subset}.{suffix}"
    if (src_dir / cut_set_name).is_file():
        logging.info(f"Loading manifest {src_dir / cut_set_name}")
        cut_set = load_manifest_lazy(src_dir / cut_set_name)
    else:
        # No cut manifest yet: build one from the recording and
        # supervision manifests.
        recordings = load_manifest_lazy(
            src_dir / f"{prefix}_recordings_{subset}.{suffix}"
        )
        supervisions = load_manifest_lazy(
            src_dir / f"{prefix}_supervisions_{subset}.{suffix}"
        )
        cut_set = CutSet.from_manifests(
            recordings=recordings,
            supervisions=supervisions,
        )

    cut_set = cut_set.resample(params.sampling_rate)

    config = TorchAudioFbankConfig(
        sampling_rate=params.sampling_rate,
        n_mels=params.num_mel_bins,
        n_fft=params.frame_length,
        hop_length=params.frame_shift,
    )
    extractor = TorchAudioFbank(config)

    cuts_filename = f"{prefix}_cuts_{subset}.{suffix}"
    if (output_dir / cuts_filename).is_file():
        logging.info(f"{prefix} {subset} already exists - skipping.")
        return

    logging.info(f"Processing {subset} of {prefix}")

    cut_set = cut_set.compute_and_store_features(
        extractor=extractor,
        storage_path=f"{output_dir}/{prefix}_feats_{subset}",
        num_jobs=num_jobs,
        storage_type=LilcomChunkyWriter,
    )
    cut_set.to_file(output_dir / cuts_filename)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    logging.info(vars(args))

    if args.split_cuts:
        compute_fbank_split(params=args)
    else:
        compute_fbank(params=args)