mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
Reformat by black
This commit is contained in:
parent
ee21954c15
commit
db5c61d371
@ -63,7 +63,7 @@ def get_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--data-dir",
|
"--data-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default='data',
|
default="data",
|
||||||
help="""Path of data directory""",
|
help="""Path of data directory""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -74,10 +74,10 @@ def compute_fbank_speechtools(
|
|||||||
bpe_model: Optional[str] = None,
|
bpe_model: Optional[str] = None,
|
||||||
dataset: Optional[str] = None,
|
dataset: Optional[str] = None,
|
||||||
perturb_speed: Optional[bool] = False,
|
perturb_speed: Optional[bool] = False,
|
||||||
data_dir: Optional[str] = 'data',
|
data_dir: Optional[str] = "data",
|
||||||
):
|
):
|
||||||
src_dir = Path(data_dir) / "manifests"
|
src_dir = Path(data_dir) / "manifests"
|
||||||
output_dir = Path(data_dir ) / "fbank"
|
output_dir = Path(data_dir) / "fbank"
|
||||||
num_jobs = min(4, os.cpu_count())
|
num_jobs = min(4, os.cpu_count())
|
||||||
num_mel_bins = 80
|
num_mel_bins = 80
|
||||||
|
|
||||||
@ -116,11 +116,11 @@ def compute_fbank_speechtools(
|
|||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
# Use cuda for fbank compute
|
# Use cuda for fbank compute
|
||||||
device = 'cuda'
|
device = "cuda"
|
||||||
else:
|
else:
|
||||||
device = 'cpu'
|
device = "cpu"
|
||||||
logging.info(f"Device: {device}")
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, device=device))
|
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, device=device))
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
@ -135,9 +135,11 @@ def compute_fbank_speechtools(
|
|||||||
recordings=m["recordings"],
|
recordings=m["recordings"],
|
||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Filter duration
|
# Filter duration
|
||||||
cut_set = cut_set.filter(lambda x: x.duration > 1 and x.sampling_rate == 16000)
|
cut_set = cut_set.filter(
|
||||||
|
lambda x: x.duration > 1 and x.sampling_rate == 16000
|
||||||
|
)
|
||||||
|
|
||||||
if "train" in partition:
|
if "train" in partition:
|
||||||
if bpe_model:
|
if bpe_model:
|
||||||
@ -150,7 +152,7 @@ def compute_fbank_speechtools(
|
|||||||
+ cut_set.perturb_speed(1.1)
|
+ cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
logging.info(f"Compute & Store features...")
|
logging.info(f"Compute & Store features...")
|
||||||
if device == 'cuda':
|
if device == "cuda":
|
||||||
cut_set = cut_set.compute_and_store_features_batch(
|
cut_set = cut_set.compute_and_store_features_batch(
|
||||||
extractor=extractor,
|
extractor=extractor,
|
||||||
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
||||||
|
@ -53,8 +53,8 @@ def is_cut_long(c: MonoCut) -> bool:
|
|||||||
def compute_fbank_musan(
|
def compute_fbank_musan(
|
||||||
src_dir: str = "data/manifests",
|
src_dir: str = "data/manifests",
|
||||||
num_mel_bins: int = 80,
|
num_mel_bins: int = 80,
|
||||||
whisper_fbank: bool = False,
|
whisper_fbank: bool = False,
|
||||||
output_dir: str = "data/fbank"
|
output_dir: str = "data/fbank",
|
||||||
):
|
):
|
||||||
src_dir = Path(src_dir)
|
src_dir = Path(src_dir)
|
||||||
output_dir = Path(output_dir)
|
output_dir = Path(output_dir)
|
||||||
|
@ -399,14 +399,14 @@ class KsponSpeechAsrDataModule:
|
|||||||
return load_manifest_lazy(
|
return load_manifest_lazy(
|
||||||
self.args.manifest_dir / "ksponspeech_cuts_dev.jsonl.gz"
|
self.args.manifest_dir / "ksponspeech_cuts_dev.jsonl.gz"
|
||||||
)
|
)
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def eval_clean_cuts(self) -> CutSet:
|
def eval_clean_cuts(self) -> CutSet:
|
||||||
logging.info("About to get eval_clean cuts")
|
logging.info("About to get eval_clean cuts")
|
||||||
return load_manifest_lazy(
|
return load_manifest_lazy(
|
||||||
self.args.manifest_dir / "ksponspeech_cuts_eval_clean.jsonl.gz"
|
self.args.manifest_dir / "ksponspeech_cuts_eval_clean.jsonl.gz"
|
||||||
)
|
)
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def eval_other_cuts(self) -> CutSet:
|
def eval_other_cuts(self) -> CutSet:
|
||||||
logging.info("About to get eval_other cuts")
|
logging.info("About to get eval_other cuts")
|
||||||
|
@ -693,7 +693,11 @@ def save_results(
|
|||||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||||
with open(errs_filename, "w") as f:
|
with open(errs_filename, "w") as f:
|
||||||
cer = write_error_stats(
|
cer = write_error_stats(
|
||||||
f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True,
|
f,
|
||||||
|
f"{test_set_name}-{key}",
|
||||||
|
results,
|
||||||
|
enable_log=True,
|
||||||
|
compute_CER=True,
|
||||||
)
|
)
|
||||||
test_set_cers[key] = cer
|
test_set_cers[key] = cer
|
||||||
|
|
||||||
|
@ -321,7 +321,6 @@ def decode_dataset(
|
|||||||
num_mel_bins=80,
|
num_mel_bins=80,
|
||||||
high_freq=-400.0,
|
high_freq=-400.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
log_interval = 50
|
log_interval = 50
|
||||||
|
|
||||||
@ -426,7 +425,11 @@ def save_results(
|
|||||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||||
with open(errs_filename, "w") as f:
|
with open(errs_filename, "w") as f:
|
||||||
cer = write_error_stats(
|
cer = write_error_stats(
|
||||||
f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True,
|
f,
|
||||||
|
f"{test_set_name}-{key}",
|
||||||
|
results,
|
||||||
|
enable_log=True,
|
||||||
|
compute_CER=True,
|
||||||
)
|
)
|
||||||
test_set_cers[key] = cer
|
test_set_cers[key] = cer
|
||||||
|
|
||||||
|
@ -1031,7 +1031,7 @@ def run(rank, world_size, args):
|
|||||||
ksponspeech = KsponSpeechAsrDataModule(args)
|
ksponspeech = KsponSpeechAsrDataModule(args)
|
||||||
|
|
||||||
train_cuts = ksponspeech.train_cuts()
|
train_cuts = ksponspeech.train_cuts()
|
||||||
|
|
||||||
def remove_short_and_long_utt(c: Cut):
|
def remove_short_and_long_utt(c: Cut):
|
||||||
# Keep only utterances with duration between 1 second and 20 seconds
|
# Keep only utterances with duration between 1 second and 20 seconds
|
||||||
#
|
#
|
||||||
@ -1083,7 +1083,7 @@ def run(rank, world_size, args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
valid_cuts = ksponspeech.dev_cuts()
|
valid_cuts = ksponspeech.dev_cuts()
|
||||||
|
|
||||||
# valid_cuts = valid_cuts.filter(remove_short_and_long_utt)
|
# valid_cuts = valid_cuts.filter(remove_short_and_long_utt)
|
||||||
valid_dl = ksponspeech.valid_dataloaders(valid_cuts)
|
valid_dl = ksponspeech.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user