fix style

This commit is contained in:
marcoyang 2023-10-10 16:48:17 +08:00
parent 7a9c18fc79
commit 9ef5dfe380

View File

@ -52,18 +52,15 @@ import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriHeavyAsrDataModule
from beam_search import (
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from beam_search import greedy_search, greedy_search_batch, modified_beam_search
from ls_text_normalization import word_normalization
from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha
from train_baseline import (
add_model_arguments,
get_params,
get_transducer_model,
from text_normalization import (
ref_text_normalization,
remove_non_alphabetic,
upper_only_alpha,
)
from train_baseline import add_model_arguments, get_params, get_transducer_model
from utils import write_error_stats
from icefall.checkpoint import (
average_checkpoints,
@ -72,13 +69,7 @@ from icefall.checkpoint import (
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
)
from utils import write_error_stats
from icefall.utils import AttributeDict, setup_logger, store_transcripts, str2bool
LOG_EPS = math.log(1e-10)
@ -260,7 +251,7 @@ def get_parser():
"--use-ls-test-set",
type=str2bool,
default=False,
help="Use librispeech test set for evaluation."
help="Use librispeech test set for evaluation.",
)
parser.add_argument(
@ -373,12 +364,6 @@ def decode_one_batch(
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
@ -442,7 +427,9 @@ def decode_dataset(
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
if not params.use_ls_test_set:
book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]]
book_names = [
cut.text_path.split("/")[-2] for cut in batch["supervisions"]["cut"]
]
else:
book_names = ["" for _ in cut_ids]
@ -458,10 +445,10 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts):
ref_text = ref_text_normalization(
ref_text
)
for cut_id, book_name, hyp_words, ref_text in zip(
cut_ids, book_names, hyps, texts
):
ref_text = ref_text_normalization(ref_text)
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
# if not params.use_ls_test_set:
@ -496,7 +483,11 @@ def save_results(
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True, biasing_words=biasing_words,
f,
f"{test_set_name}-{key}",
results,
enable_log=True,
biasing_words=biasing_words,
)
test_set_wers[key] = wer
@ -504,7 +495,9 @@ def save_results(
if params.compute_CER:
# Write CER statistics
recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
recog_path = (
params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results, char_level=True)
errs_filename = (
params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
@ -522,9 +515,7 @@ def save_results(
logging.info("Wrote detailed CER stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
)
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
@ -539,9 +530,7 @@ def save_results(
if params.compute_CER:
test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
)
errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tcER", file=f)
for key, val in test_set_cers:
@ -580,9 +569,6 @@ def main():
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.random_left_padding:
params.suffix += f"random-left-padding"
if "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
@ -717,11 +703,17 @@ def main():
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
long_audio_cuts = libriheavy.long_audio_cuts()
test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts,)
test_other_dl = libriheavy.valid_dataloaders(test_other_cuts,)
test_clean_dl = libriheavy.valid_dataloaders(
test_clean_cuts,
)
test_other_dl = libriheavy.valid_dataloaders(
test_other_cuts,
)
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
long_audio_dl = libriheavy.valid_dataloaders(long_audio_cuts,)
long_audio_dl = libriheavy.valid_dataloaders(
long_audio_cuts,
)
if params.use_ls_test_set:
test_sets = ["ls-test-clean", "ls-test-other"]
@ -736,7 +728,9 @@ def main():
for test_set, test_dl in zip(test_sets, test_dl):
if params.use_ls_test_set:
f = open("data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", 'r')
f = open(
"data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", "r"
)
biasing_words = f.read().strip().split()
f.close()
else:
@ -766,7 +760,9 @@ def main():
for item in results_dict[k]:
id, ref, hyp = item
if params.use_ls_test_set:
hyp = " ".join(hyp).replace("-", " ").split() # handle the hypens
hyp = (
" ".join(hyp).replace("-", " ").split()
) # handle the hypens
hyp = upper_only_alpha(" ".join(hyp)).split()
hyp = [word_normalization(w.upper()) for w in hyp]
hyp = " ".join(hyp).split()
@ -775,7 +771,7 @@ def main():
else:
hyp = upper_only_alpha(" ".join(hyp)).split()
ref = upper_only_alpha(" ".join(ref)).split()
new_ans.append((id,ref,hyp))
new_ans.append((id, ref, hyp))
new_res[k] = new_ans
save_results(
@ -788,7 +784,6 @@ def main():
if params.suffix.endswith("-post-normalization"):
params.suffix = params.suffix.replace("-post-normalization", "")
logging.info("Done!")