Reformat by black

This commit is contained in:
whsqkaak 2024-06-13 15:02:59 +09:00
parent ee21954c15
commit db5c61d371
6 changed files with 27 additions and 18 deletions

View File

@ -63,7 +63,7 @@ def get_args():
parser.add_argument( parser.add_argument(
"--data-dir", "--data-dir",
type=str, type=str,
default='data', default="data",
help="""Path of data directory""", help="""Path of data directory""",
) )
@ -74,10 +74,10 @@ def compute_fbank_speechtools(
bpe_model: Optional[str] = None, bpe_model: Optional[str] = None,
dataset: Optional[str] = None, dataset: Optional[str] = None,
perturb_speed: Optional[bool] = False, perturb_speed: Optional[bool] = False,
data_dir: Optional[str] = 'data', data_dir: Optional[str] = "data",
): ):
src_dir = Path(data_dir) / "manifests" src_dir = Path(data_dir) / "manifests"
output_dir = Path(data_dir ) / "fbank" output_dir = Path(data_dir) / "fbank"
num_jobs = min(4, os.cpu_count()) num_jobs = min(4, os.cpu_count())
num_mel_bins = 80 num_mel_bins = 80
@ -116,9 +116,9 @@ def compute_fbank_speechtools(
if torch.cuda.is_available(): if torch.cuda.is_available():
# Use cuda for fbank compute # Use cuda for fbank compute
device = 'cuda' device = "cuda"
else: else:
device = 'cpu' device = "cpu"
logging.info(f"Device: {device}") logging.info(f"Device: {device}")
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, device=device)) extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, device=device))
@ -137,7 +137,9 @@ def compute_fbank_speechtools(
) )
# Filter duration # Filter duration
cut_set = cut_set.filter(lambda x: x.duration > 1 and x.sampling_rate == 16000) cut_set = cut_set.filter(
lambda x: x.duration > 1 and x.sampling_rate == 16000
)
if "train" in partition: if "train" in partition:
if bpe_model: if bpe_model:
@ -150,7 +152,7 @@ def compute_fbank_speechtools(
+ cut_set.perturb_speed(1.1) + cut_set.perturb_speed(1.1)
) )
logging.info(f"Compute & Store features...") logging.info(f"Compute & Store features...")
if device == 'cuda': if device == "cuda":
cut_set = cut_set.compute_and_store_features_batch( cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor, extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}", storage_path=f"{output_dir}/{prefix}_feats_{partition}",

View File

@ -54,7 +54,7 @@ def compute_fbank_musan(
src_dir: str = "data/manifests", src_dir: str = "data/manifests",
num_mel_bins: int = 80, num_mel_bins: int = 80,
whisper_fbank: bool = False, whisper_fbank: bool = False,
output_dir: str = "data/fbank" output_dir: str = "data/fbank",
): ):
src_dir = Path(src_dir) src_dir = Path(src_dir)
output_dir = Path(output_dir) output_dir = Path(output_dir)

View File

@ -693,7 +693,11 @@ def save_results(
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
cer = write_error_stats( cer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True, f,
f"{test_set_name}-{key}",
results,
enable_log=True,
compute_CER=True,
) )
test_set_cers[key] = cer test_set_cers[key] = cer

View File

@ -322,7 +322,6 @@ def decode_dataset(
high_freq=-400.0, high_freq=-400.0,
) )
log_interval = 50 log_interval = 50
decode_results = [] decode_results = []
@ -426,7 +425,11 @@ def save_results(
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
cer = write_error_stats( cer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True, f,
f"{test_set_name}-{key}",
results,
enable_log=True,
compute_CER=True,
) )
test_set_cers[key] = cer test_set_cers[key] = cer