mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
Reformat by black
This commit is contained in:
parent
ee21954c15
commit
db5c61d371
@ -63,7 +63,7 @@ def get_args():
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
type=str,
|
||||
default='data',
|
||||
default="data",
|
||||
help="""Path of data directory""",
|
||||
)
|
||||
|
||||
@ -74,7 +74,7 @@ def compute_fbank_speechtools(
|
||||
bpe_model: Optional[str] = None,
|
||||
dataset: Optional[str] = None,
|
||||
perturb_speed: Optional[bool] = False,
|
||||
data_dir: Optional[str] = 'data',
|
||||
data_dir: Optional[str] = "data",
|
||||
):
|
||||
src_dir = Path(data_dir) / "manifests"
|
||||
output_dir = Path(data_dir) / "fbank"
|
||||
@ -116,9 +116,9 @@ def compute_fbank_speechtools(
|
||||
|
||||
if torch.cuda.is_available():
|
||||
# Use cuda for fbank compute
|
||||
device = 'cuda'
|
||||
device = "cuda"
|
||||
else:
|
||||
device = 'cpu'
|
||||
device = "cpu"
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, device=device))
|
||||
@ -137,7 +137,9 @@ def compute_fbank_speechtools(
|
||||
)
|
||||
|
||||
# Filter duration
|
||||
cut_set = cut_set.filter(lambda x: x.duration > 1 and x.sampling_rate == 16000)
|
||||
cut_set = cut_set.filter(
|
||||
lambda x: x.duration > 1 and x.sampling_rate == 16000
|
||||
)
|
||||
|
||||
if "train" in partition:
|
||||
if bpe_model:
|
||||
@ -150,7 +152,7 @@ def compute_fbank_speechtools(
|
||||
+ cut_set.perturb_speed(1.1)
|
||||
)
|
||||
logging.info(f"Compute & Store features...")
|
||||
if device == 'cuda':
|
||||
if device == "cuda":
|
||||
cut_set = cut_set.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
||||
|
@ -54,7 +54,7 @@ def compute_fbank_musan(
|
||||
src_dir: str = "data/manifests",
|
||||
num_mel_bins: int = 80,
|
||||
whisper_fbank: bool = False,
|
||||
output_dir: str = "data/fbank"
|
||||
output_dir: str = "data/fbank",
|
||||
):
|
||||
src_dir = Path(src_dir)
|
||||
output_dir = Path(output_dir)
|
||||
|
@ -693,7 +693,11 @@ def save_results(
|
||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
cer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True,
|
||||
f,
|
||||
f"{test_set_name}-{key}",
|
||||
results,
|
||||
enable_log=True,
|
||||
compute_CER=True,
|
||||
)
|
||||
test_set_cers[key] = cer
|
||||
|
||||
|
@ -322,7 +322,6 @@ def decode_dataset(
|
||||
high_freq=-400.0,
|
||||
)
|
||||
|
||||
|
||||
log_interval = 50
|
||||
|
||||
decode_results = []
|
||||
@ -426,7 +425,11 @@ def save_results(
|
||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
cer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True,
|
||||
f,
|
||||
f"{test_set_name}-{key}",
|
||||
results,
|
||||
enable_log=True,
|
||||
compute_CER=True,
|
||||
)
|
||||
test_set_cers[key] = cer
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user