Reformat by black non-streaming zipformer recipe for ksponspeech (#1665)

This commit is contained in:
Seung Hyun Lee 2024-06-24 16:28:09 +09:00 committed by GitHub
parent 6f102d3470
commit 031f892796
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 16 additions and 8 deletions

View File

@ -571,7 +571,9 @@ def save_results(
# ref/hyp pairs. # ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
cer = write_error_stats(f, f"{test_set_name}-{key}", results, compute_CER=True) cer = write_error_stats(
f, f"{test_set_name}-{key}", results, compute_CER=True
)
test_set_cers[key] = cer test_set_cers[key] = cer
logging.info("Wrote detailed error stats to {}".format(errs_filename)) logging.info("Wrote detailed error stats to {}".format(errs_filename))
@ -807,7 +809,7 @@ def main():
# we need cut ids to display recognition results. # we need cut ids to display recognition results.
args.return_cuts = True args.return_cuts = True
ksponspeech = KsponSpeechAsrDataModule(args) ksponspeech = KsponSpeechAsrDataModule(args)
eval_clean_cuts = ksponspeech.eval_clean_cuts() eval_clean_cuts = ksponspeech.eval_clean_cuts()
@ -815,7 +817,7 @@ def main():
eval_clean_dl = ksponspeech.test_dataloaders(eval_clean_cuts) eval_clean_dl = ksponspeech.test_dataloaders(eval_clean_cuts)
eval_other_dl = ksponspeech.test_dataloaders(eval_other_cuts) eval_other_dl = ksponspeech.test_dataloaders(eval_other_cuts)
test_sets = ["eval_clean", "eval_other"] test_sets = ["eval_clean", "eval_other"]
test_dl = [eval_clean_dl, eval_other_dl] test_dl = [eval_clean_dl, eval_other_dl]

View File

@ -727,7 +727,11 @@ def save_results(
) )
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
cer = write_error_stats( cer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True, f,
f"{test_set_name}-{key}",
results,
enable_log=True,
compute_CER=True,
) )
test_set_cers[key] = cer test_set_cers[key] = cer

View File

@ -659,7 +659,11 @@ def save_results(
) )
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
cer = write_error_stats( cer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True, f,
f"{test_set_name}-{key}",
results,
enable_log=True,
compute_CER=True,
) )
test_set_cers[key] = cer test_set_cers[key] = cer

View File

@ -961,9 +961,7 @@ def train_one_epoch(
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
except Exception as e: except Exception as e:
logging.info( logging.info(f"Caught exception: {e}.")
f"Caught exception: {e}."
)
save_bad_model() save_bad_model()
display_and_save_batch(batch, params=params, sp=sp) display_and_save_batch(batch, params=params, sp=sp)
raise raise