Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-06 23:54:17 +00:00)
commit 9ef5dfe380
parent 7a9c18fc79

    fix style
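The hunks below are line-length and layout changes of the kind produced by Python auto-formatters; the commit message only says "fix style", so attributing them to black and isort is an assumption. A minimal sketch, assuming black and isort are installed (pip install black isort), of the two rewrites that recur in this diff:

# Hedged illustration only: the input strings are copied from the diff below,
# and the formatter choice (black + isort) is an assumption, not stated by the commit.
import black
import isort

# A trailing "magic" comma makes black explode a call onto one argument per line,
# as in the valid_dataloaders and write_error_stats hunks further down.
call = "test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts,)\n"
print(black.format_str(call, mode=black.Mode(line_length=88)))

# isort re-joins a parenthesized import onto a single line when it fits within
# the line length, as in the beam_search import hunk.
imports = (
    "from beam_search import (\n"
    "    greedy_search,\n"
    "    greedy_search_batch,\n"
    "    modified_beam_search,\n"
    ")\n"
)
print(isort.code(imports, profile="black"))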
@@ -52,18 +52,15 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriHeavyAsrDataModule
-from beam_search import (
-    greedy_search,
-    greedy_search_batch,
-    modified_beam_search,
-)
+from beam_search import greedy_search, greedy_search_batch, modified_beam_search
 from ls_text_normalization import word_normalization
-from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha
-from train_baseline import (
-    add_model_arguments,
-    get_params,
-    get_transducer_model,
+from text_normalization import (
+    ref_text_normalization,
+    remove_non_alphabetic,
+    upper_only_alpha,
 )
+from train_baseline import add_model_arguments, get_params, get_transducer_model
+from utils import write_error_stats
 
 from icefall.checkpoint import (
     average_checkpoints,
@@ -72,13 +69,7 @@ from icefall.checkpoint import (
     load_checkpoint,
 )
 from icefall.lexicon import Lexicon
-from icefall.utils import (
-    AttributeDict,
-    setup_logger,
-    store_transcripts,
-    str2bool,
-)
-from utils import write_error_stats
+from icefall.utils import AttributeDict, setup_logger, store_transcripts, str2bool
 
 LOG_EPS = math.log(1e-10)
 
@@ -260,7 +251,7 @@ def get_parser():
         "--use-ls-test-set",
         type=str2bool,
         default=False,
-        help="Use librispeech test set for evaluation."
+        help="Use librispeech test set for evaluation.",
     )
 
     parser.add_argument(
@@ -373,12 +364,6 @@ def decode_one_batch(
                     encoder_out=encoder_out_i,
                     max_sym_per_frame=params.max_sym_per_frame,
                 )
-            elif params.decoding_method == "beam_search":
-                hyp = beam_search(
-                    model=model,
-                    encoder_out=encoder_out_i,
-                    beam=params.beam_size,
-                )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
@@ -442,7 +427,9 @@ def decode_dataset(
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         if not params.use_ls_test_set:
-            book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]]
+            book_names = [
+                cut.text_path.split("/")[-2] for cut in batch["supervisions"]["cut"]
+            ]
         else:
             book_names = ["" for _ in cut_ids]
 
@@ -458,10 +445,10 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts):
-                ref_text = ref_text_normalization(
-                    ref_text
-                )
+            for cut_id, book_name, hyp_words, ref_text in zip(
+                cut_ids, book_names, hyps, texts
+            ):
+                ref_text = ref_text_normalization(ref_text)
                 ref_words = ref_text.split()
                 this_batch.append((cut_id, ref_words, hyp_words))
                 # if not params.use_ls_test_set:
@@ -496,7 +483,11 @@ def save_results(
         errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=True, biasing_words=biasing_words,
+                f,
+                f"{test_set_name}-{key}",
+                results,
+                enable_log=True,
+                biasing_words=biasing_words,
             )
             test_set_wers[key] = wer
 
@@ -504,7 +495,9 @@ def save_results(
 
         if params.compute_CER:
             # Write CER statistics
-            recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
+            recog_path = (
+                params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
+            )
             store_transcripts(filename=recog_path, texts=results, char_level=True)
             errs_filename = (
                 params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
@@ -522,9 +515,7 @@ def save_results(
             logging.info("Wrote detailed CER stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -539,9 +530,7 @@ def save_results(
 
     if params.compute_CER:
         test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
-        errs_info = (
-            params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
        with open(errs_info, "w") as f:
             print("settings\tcER", file=f)
             for key, val in test_set_cers:
@@ -580,9 +569,6 @@ def main():
     else:
         params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
 
-    if params.random_left_padding:
-        params.suffix += f"random-left-padding"
-
     if "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
     else:
@@ -717,11 +703,17 @@ def main():
     ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
     long_audio_cuts = libriheavy.long_audio_cuts()
 
-    test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts,)
-    test_other_dl = libriheavy.valid_dataloaders(test_other_cuts,)
+    test_clean_dl = libriheavy.valid_dataloaders(
+        test_clean_cuts,
+    )
+    test_other_dl = libriheavy.valid_dataloaders(
+        test_other_cuts,
+    )
     ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
     ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
-    long_audio_dl = libriheavy.valid_dataloaders(long_audio_cuts,)
+    long_audio_dl = libriheavy.valid_dataloaders(
+        long_audio_cuts,
+    )
 
     if params.use_ls_test_set:
         test_sets = ["ls-test-clean", "ls-test-other"]
@@ -736,7 +728,9 @@ def main():
 
     for test_set, test_dl in zip(test_sets, test_dl):
         if params.use_ls_test_set:
-            f = open("data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", 'r')
+            f = open(
+                "data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", "r"
+            )
             biasing_words = f.read().strip().split()
             f.close()
         else:
@@ -766,7 +760,9 @@ def main():
                 for item in results_dict[k]:
                     id, ref, hyp = item
                     if params.use_ls_test_set:
-                        hyp = " ".join(hyp).replace("-", " ").split()  # handle the hypens
+                        hyp = (
+                            " ".join(hyp).replace("-", " ").split()
+                        )  # handle the hypens
                         hyp = upper_only_alpha(" ".join(hyp)).split()
                         hyp = [word_normalization(w.upper()) for w in hyp]
                         hyp = " ".join(hyp).split()
@@ -775,7 +771,7 @@ def main():
                     else:
                         hyp = upper_only_alpha(" ".join(hyp)).split()
                         ref = upper_only_alpha(" ".join(ref)).split()
-                    new_ans.append((id,ref,hyp))
+                    new_ans.append((id, ref, hyp))
                 new_res[k] = new_ans
 
             save_results(
@@ -788,7 +784,6 @@ def main():
         if params.suffix.endswith("-post-normalization"):
             params.suffix = params.suffix.replace("-post-normalization", "")
 
-
     logging.info("Done!")
 
 