Merge 1775f36a612b4bbc351077d51c4e40540678d61f into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9

RedSheep 2025-07-25 09:16:26 +02:00 committed by GitHub
commit 7613a5e40c
3 changed files with 47 additions and 13 deletions

egs/librispeech/ASR/local/compute_fbank_librispeech.py

@@ -68,6 +68,13 @@ def get_args():
         help="""Perturb speed with factor 0.9 and 1.1 on train subset.""",
     )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=15,
+        help="Number of worker processes for feature extraction.",
+    )
+
     return parser.parse_args()
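For reference, a minimal standalone sketch of how this flag behaves (an illustrative snippet, not part of the patch): argparse converts the dashed option name into the underscored attribute args.num_workers, which is what the script reads later.

import argparse

# Minimal sketch mirroring the --num-workers flag added above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--num-workers",
    type=int,
    default=15,
    help="Number of worker processes for feature extraction.",
)

args = parser.parse_args(["--num-workers", "8"])
print(args.num_workers)  # -> 8; dashes in the flag become underscores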
@@ -75,10 +82,11 @@ def compute_fbank_librispeech(
     bpe_model: Optional[str] = None,
     dataset: Optional[str] = None,
     perturb_speed: Optional[bool] = True,
+    num_workers: int = 15,
 ):
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
-    num_jobs = min(15, os.cpu_count())
+    num_jobs = min(num_workers, os.cpu_count())
     num_mel_bins = 80

     if bpe_model:
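The clamp above caps the job count at the machine's core count; a quick illustration with assumed values:

import os

# With the default of 15 requested workers on, say, an 8-core machine,
# the effective job count is the smaller of the two values.
num_workers = 15
num_jobs = min(num_workers, os.cpu_count())
print(num_jobs)  # 8 on an 8-core host; 15 on a 32-core host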
@@ -125,6 +133,7 @@ def compute_fbank_librispeech(
                 logging.info(f"{partition} already exists - skipping.")
                 continue
             logging.info(f"Processing {partition}")
+
             cut_set = CutSet.from_manifests(
                 recordings=m["recordings"],
                 supervisions=m["supervisions"],
@@ -134,20 +143,44 @@ def compute_fbank_librispeech(
                 if bpe_model:
                     cut_set = filter_cuts(cut_set, sp)
                 if perturb_speed:
-                    logging.info(f"Doing speed perturb")
+                    logging.info("Doing speed perturb")
                     cut_set = (
                         cut_set
                         + cut_set.perturb_speed(0.9)
                         + cut_set.perturb_speed(1.1)
                     )
-            cut_set = cut_set.compute_and_store_features(
-                extractor=extractor,
-                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
-                # when an executor is specified, make more partitions
-                num_jobs=num_jobs if ex is None else 80,
-                executor=ex,
-                storage_type=LilcomChunkyWriter,
-            )
+
+            if ex is None:
+                # Create a custom process pool context for None (local execution)
+                import multiprocessing as mp
+                from concurrent.futures import ProcessPoolExecutor
+
+                # Calculate the number of jobs
+                actual_jobs = (
+                    min(num_jobs * 2, 20) if "train" in partition else num_jobs
+                )
+
+                # Use the forkserver method
+                ctx = mp.get_context("forkserver")
+                with ProcessPoolExecutor(
+                    max_workers=actual_jobs, mp_context=ctx
+                ) as local_executor:
+                    cut_set = cut_set.compute_and_store_features(
+                        extractor=extractor,
+                        storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+                        executor=local_executor,
+                        storage_type=LilcomChunkyWriter,
+                    )
+            else:
+                # Distributed environment, use the provided executor
+                cut_set = cut_set.compute_and_store_features(
+                    extractor=extractor,
+                    storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+                    num_jobs=min(num_jobs * 2, 20),
+                    executor=ex,
+                    storage_type=LilcomChunkyWriter,
+                )
+
             cut_set.to_file(output_dir / cuts_filename)
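The core of this change is creating the local ProcessPoolExecutor from a "forkserver" multiprocessing context instead of relying on the default "fork" start method: forkserver launches workers from a clean helper process, so they do not inherit locks or threads (e.g. from BLAS/OpenMP pools) that can deadlock fork()-ed children. A minimal, self-contained sketch of the same pattern, with a toy task standing in for feature extraction:

import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor


def square(x: int) -> int:
    # Toy stand-in for per-cut feature extraction work.
    return x * x


if __name__ == "__main__":
    # "forkserver" spawns workers from a clean server process rather than
    # fork()-ing the (possibly multi-threaded) parent directly.
    ctx = mp.get_context("forkserver")
    with ProcessPoolExecutor(max_workers=4, mp_context=ctx) as executor:
        print(list(executor.map(square, range(8))))  # [0, 1, 4, 9, 16, 25, 36, 49]

Note that forkserver is only available on POSIX systems, and functions submitted to the pool must be importable at module top level (as square is here).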
@ -161,4 +194,5 @@ if __name__ == "__main__":
bpe_model=args.bpe_model, bpe_model=args.bpe_model,
dataset=args.dataset, dataset=args.dataset,
perturb_speed=args.perturb_speed, perturb_speed=args.perturb_speed,
num_workers=args.num_workers,
) )

egs/librispeech/ASR/prepare.sh

@@ -139,7 +139,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "Stage 3: Compute fbank for librispeech"
   mkdir -p data/fbank
   if [ ! -e data/fbank/.librispeech.done ]; then
-    ./local/compute_fbank_librispeech.py
+    ./local/compute_fbank_librispeech.py --num-workers $nj
     touch data/fbank/.librispeech.done
   fi

icefall/utils.py

@@ -180,10 +180,10 @@ def setup_logger(
         world_size = dist.get_world_size()
         rank = dist.get_rank()
         formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
-        log_filename = f"{log_filename}-{date_time}-{rank}"
+        log_filename = f"{log_filename}-{date_time}-{rank}.log"
     else:
         formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-        log_filename = f"{log_filename}-{date_time}"
+        log_filename = f"{log_filename}-{date_time}.log"

     os.makedirs(os.path.dirname(log_filename), exist_ok=True)
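The change here only appends a .log extension to the generated filename. With a timestamp and DDP rank filled in, a resulting path might look as follows (the exact date format is an assumption; the hunk does not show how date_time is built):

from datetime import datetime

# Hypothetical reconstruction of the filename pattern above.
log_filename = "log/log-train"
date_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")  # assumed format
rank = 0
print(f"{log_filename}-{date_time}-{rank}.log")
# e.g. log/log-train-2025-07-25-09-16-26-0.log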