mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Merge 1775f36a612b4bbc351077d51c4e40540678d61f into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9
This commit is contained in:
commit
7613a5e40c
@ -68,6 +68,13 @@ def get_args():
|
|||||||
help="""Perturb speed with factor 0.9 and 1.1 on train subset.""",
|
help="""Perturb speed with factor 0.9 and 1.1 on train subset.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-workers",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of worker processes for feature extraction.",
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -75,10 +82,11 @@ def compute_fbank_librispeech(
|
|||||||
bpe_model: Optional[str] = None,
|
bpe_model: Optional[str] = None,
|
||||||
dataset: Optional[str] = None,
|
dataset: Optional[str] = None,
|
||||||
perturb_speed: Optional[bool] = True,
|
perturb_speed: Optional[bool] = True,
|
||||||
|
num_workers: int = 15,
|
||||||
):
|
):
|
||||||
src_dir = Path("data/manifests")
|
src_dir = Path("data/manifests")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(num_workers, os.cpu_count())
|
||||||
num_mel_bins = 80
|
num_mel_bins = 80
|
||||||
|
|
||||||
if bpe_model:
|
if bpe_model:
|
||||||
@ -125,6 +133,7 @@ def compute_fbank_librispeech(
|
|||||||
logging.info(f"{partition} already exists - skipping.")
|
logging.info(f"{partition} already exists - skipping.")
|
||||||
continue
|
continue
|
||||||
logging.info(f"Processing {partition}")
|
logging.info(f"Processing {partition}")
|
||||||
|
|
||||||
cut_set = CutSet.from_manifests(
|
cut_set = CutSet.from_manifests(
|
||||||
recordings=m["recordings"],
|
recordings=m["recordings"],
|
||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
@ -134,20 +143,44 @@ def compute_fbank_librispeech(
|
|||||||
if bpe_model:
|
if bpe_model:
|
||||||
cut_set = filter_cuts(cut_set, sp)
|
cut_set = filter_cuts(cut_set, sp)
|
||||||
if perturb_speed:
|
if perturb_speed:
|
||||||
logging.info(f"Doing speed perturb")
|
logging.info("Doing speed perturb")
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set
|
cut_set
|
||||||
+ cut_set.perturb_speed(0.9)
|
+ cut_set.perturb_speed(0.9)
|
||||||
+ cut_set.perturb_speed(1.1)
|
+ cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if ex is None:
|
||||||
|
# Create a custom process pool context for None (local execution)
|
||||||
|
import multiprocessing as mp
|
||||||
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
|
|
||||||
|
# Calculate the number of jobs
|
||||||
|
actual_jobs = (
|
||||||
|
min(num_jobs * 2, 20) if "train" in partition else num_jobs
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the forkserver method
|
||||||
|
ctx = mp.get_context("forkserver")
|
||||||
|
with ProcessPoolExecutor(
|
||||||
|
max_workers=actual_jobs, mp_context=ctx
|
||||||
|
) as local_executor:
|
||||||
cut_set = cut_set.compute_and_store_features(
|
cut_set = cut_set.compute_and_store_features(
|
||||||
extractor=extractor,
|
extractor=extractor,
|
||||||
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
||||||
# when an executor is specified, make more partitions
|
executor=local_executor,
|
||||||
num_jobs=num_jobs if ex is None else 80,
|
storage_type=LilcomChunkyWriter,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Distributed environment, use the provided executor
|
||||||
|
cut_set = cut_set.compute_and_store_features(
|
||||||
|
extractor=extractor,
|
||||||
|
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
||||||
|
num_jobs=min(num_jobs * 2, 20),
|
||||||
executor=ex,
|
executor=ex,
|
||||||
storage_type=LilcomChunkyWriter,
|
storage_type=LilcomChunkyWriter,
|
||||||
)
|
)
|
||||||
|
|
||||||
cut_set.to_file(output_dir / cuts_filename)
|
cut_set.to_file(output_dir / cuts_filename)
|
||||||
|
|
||||||
|
|
||||||
@ -161,4 +194,5 @@ if __name__ == "__main__":
|
|||||||
bpe_model=args.bpe_model,
|
bpe_model=args.bpe_model,
|
||||||
dataset=args.dataset,
|
dataset=args.dataset,
|
||||||
perturb_speed=args.perturb_speed,
|
perturb_speed=args.perturb_speed,
|
||||||
|
num_workers=args.num_workers,
|
||||||
)
|
)
|
||||||
|
@ -139,7 +139,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
|||||||
log "Stage 3: Compute fbank for librispeech"
|
log "Stage 3: Compute fbank for librispeech"
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
if [ ! -e data/fbank/.librispeech.done ]; then
|
if [ ! -e data/fbank/.librispeech.done ]; then
|
||||||
./local/compute_fbank_librispeech.py
|
./local/compute_fbank_librispeech.py --num-workers $nj
|
||||||
touch data/fbank/.librispeech.done
|
touch data/fbank/.librispeech.done
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
@ -180,10 +180,10 @@ def setup_logger(
|
|||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa
|
formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa
|
||||||
log_filename = f"{log_filename}-{date_time}-{rank}"
|
log_filename = f"{log_filename}-{date_time}-{rank}.log"
|
||||||
else:
|
else:
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
log_filename = f"{log_filename}-{date_time}"
|
log_filename = f"{log_filename}-{date_time}.log"
|
||||||
|
|
||||||
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
|
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user