mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 07:04:18 +00:00
Fix black
This commit is contained in:
parent
737125f52b
commit
ddc52d5839
@ -21,7 +21,14 @@ import logging
|
|||||||
import torch
|
import torch
|
||||||
import lhotse
|
import lhotse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, fix_manifests, validate_recordings_and_supervisions
|
from lhotse import (
|
||||||
|
CutSet,
|
||||||
|
Fbank,
|
||||||
|
FbankConfig,
|
||||||
|
LilcomChunkyWriter,
|
||||||
|
fix_manifests,
|
||||||
|
validate_recordings_and_supervisions,
|
||||||
|
)
|
||||||
from icefall.utils import get_executor, str2bool
|
from icefall.utils import get_executor, str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
@ -31,6 +38,7 @@ from icefall.utils import get_executor, str2bool
|
|||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
@ -79,10 +87,7 @@ def get_args():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-jobs",
|
"--num-jobs", type=int, default=50, help="The num of jobs to extract feature."
|
||||||
type=int,
|
|
||||||
default=50,
|
|
||||||
help="The num of jobs to extract feature."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
@ -109,11 +114,7 @@ def compute_feature(args, cuts):
|
|||||||
if "train" in args.partition:
|
if "train" in args.partition:
|
||||||
if args.perturb_speed:
|
if args.perturb_speed:
|
||||||
logging.info(f"Doing speed perturb")
|
logging.info(f"Doing speed perturb")
|
||||||
cuts = (
|
cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
|
||||||
cuts
|
|
||||||
+ cuts.perturb_speed(0.9)
|
|
||||||
+ cuts.perturb_speed(1.1)
|
|
||||||
)
|
|
||||||
cuts = cuts.compute_and_store_features(
|
cuts = cuts.compute_and_store_features(
|
||||||
extractor=extractor,
|
extractor=extractor,
|
||||||
storage_path=f"{args.output_dir}/{args.dataset}_feats_{args.partition}",
|
storage_path=f"{args.output_dir}/{args.dataset}_feats_{args.partition}",
|
||||||
@ -132,7 +133,7 @@ def main(args):
|
|||||||
compute_feature(args, cuts)
|
compute_feature(args, cuts)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user