diff --git a/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py b/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py index 8412815b1..334a6d023 100644 --- a/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py +++ b/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py @@ -21,7 +21,14 @@ import logging import torch import lhotse from pathlib import Path -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, fix_manifests, validate_recordings_and_supervisions +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + fix_manifests, + validate_recordings_and_supervisions, +) from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or @@ -31,6 +38,7 @@ from icefall.utils import get_executor, str2bool torch.set_num_threads(1) torch.set_num_interop_threads(1) + def get_args(): parser = argparse.ArgumentParser() @@ -79,10 +87,7 @@ def get_args(): ) parser.add_argument( - "--num-jobs", - type=int, - default=50, - help="The num of jobs to extract feature." + "--num-jobs", type=int, default=50, help="The num of jobs to extract feature." ) return parser.parse_args() @@ -109,11 +114,7 @@ def compute_feature(args, cuts): if "train" in args.partition: if args.perturb_speed: logging.info(f"Doing speed perturb") - cuts = ( - cuts - + cuts.perturb_speed(0.9) - + cuts.perturb_speed(1.1) - ) + cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1) cuts = cuts.compute_and_store_features( extractor=extractor, storage_path=f"{args.output_dir}/{args.dataset}_feats_{args.partition}", @@ -132,7 +133,7 @@ def main(args): compute_feature(args, cuts) -if __name__ == '__main__': +if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO)