Mirror of https://github.com/k2-fsa/icefall.git

Commit 9ef5dfe380 ("fix style"), parent 7a9c18fc79.
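The edits below are mechanical formatting fixes. The commit message says only "fix style", but the shape of the changes matches icefall's style tooling, black and isort; that attribution, and the sketch below, are inferences for illustration, not part of the commit. Two behaviors explain most hunks: isort collapses or wraps import lists to fit the 88-column limit, and black's "magic trailing comma" explodes a one-line call that ends with a comma:

    # Illustrative sketch only (assumes `pip install black isort`; the exact
    # versions icefall pins are not stated in this commit).
    import black
    import isort

    # isort joins a parenthesized import back onto one line when the result
    # fits within 88 columns, as in the first hunk below.
    imports = (
        "from beam_search import (\n"
        "    greedy_search,\n"
        "    greedy_search_batch,\n"
        "    modified_beam_search,\n"
        ")\n"
    )
    print(isort.code(imports, profile="black"))

    # black's magic trailing comma explodes a one-line call that ends with a
    # comma, as in the dataloader hunk below.
    call = "test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts,)\n"
    print(black.format_str(call, mode=black.Mode()))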
@@ -52,18 +52,15 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriHeavyAsrDataModule
-from beam_search import (
-    greedy_search,
-    greedy_search_batch,
-    modified_beam_search,
-)
+from beam_search import greedy_search, greedy_search_batch, modified_beam_search
 from ls_text_normalization import word_normalization
-from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha
-from train_baseline import (
-    add_model_arguments,
-    get_params,
-    get_transducer_model,
-)
+from text_normalization import (
+    ref_text_normalization,
+    remove_non_alphabetic,
+    upper_only_alpha,
+)
+from train_baseline import add_model_arguments, get_params, get_transducer_model
+from utils import write_error_stats
 
 from icefall.checkpoint import (
     average_checkpoints,
@@ -72,13 +69,7 @@ from icefall.checkpoint import (
     load_checkpoint,
 )
 from icefall.lexicon import Lexicon
-from icefall.utils import (
-    AttributeDict,
-    setup_logger,
-    store_transcripts,
-    str2bool,
-)
-from utils import write_error_stats
+from icefall.utils import AttributeDict, setup_logger, store_transcripts, str2bool
 
 LOG_EPS = math.log(1e-10)
 
@@ -242,34 +233,34 @@ def get_parser():
         Used only when the decoding method is fast_beam_search_nbest,
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )
 
     parser.add_argument(
         "--post-normalization",
         type=str2bool,
         default=True,
         help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ",
     )
 
     parser.add_argument(
         "--long-audio-recog",
         type=str2bool,
         default=False,
     )
 
     parser.add_argument(
         "--use-ls-test-set",
         type=str2bool,
         default=False,
-        help="Use librispeech test set for evaluation."
+        help="Use librispeech test set for evaluation.",
     )
 
     parser.add_argument(
         "--compute-CER",
         type=str2bool,
         default=True,
         help="Reports CER. By default, only reports WER",
     )
 
     add_model_arguments(parser)
 
     return parser
@@ -333,14 +324,14 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         encoder_out, encoder_out_lens = model.encode_audio(
             feature=feature,
             feature_lens=feature_lens,
         )
 
     hyps = []
 
     if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
@@ -373,12 +364,6 @@ def decode_one_batch(
                     encoder_out=encoder_out_i,
                     max_sym_per_frame=params.max_sym_per_frame,
                 )
-            elif params.decoding_method == "beam_search":
-                hyp = beam_search(
-                    model=model,
-                    encoder_out=encoder_out_i,
-                    beam=params.beam_size,
-                )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
@@ -442,7 +427,9 @@ def decode_dataset(
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         if not params.use_ls_test_set:
-            book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]]
+            book_names = [
+                cut.text_path.split("/")[-2] for cut in batch["supervisions"]["cut"]
+            ]
         else:
             book_names = ["" for _ in cut_ids]
 
@@ -458,10 +445,10 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts):
-                ref_text = ref_text_normalization(
-                    ref_text
-                )
+            for cut_id, book_name, hyp_words, ref_text in zip(
+                cut_ids, book_names, hyps, texts
+            ):
+                ref_text = ref_text_normalization(ref_text)
                 ref_words = ref_text.split()
                 this_batch.append((cut_id, ref_words, hyp_words))
                 # if not params.use_ls_test_set:
@@ -496,7 +483,11 @@ def save_results(
         errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=True, biasing_words=biasing_words,
+                f,
+                f"{test_set_name}-{key}",
+                results,
+                enable_log=True,
+                biasing_words=biasing_words,
             )
             test_set_wers[key] = wer
 
@@ -504,7 +495,9 @@ def save_results(
 
         if params.compute_CER:
             # Write CER statistics
-            recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
+            recog_path = (
+                params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
+            )
             store_transcripts(filename=recog_path, texts=results, char_level=True)
             errs_filename = (
                 params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
@@ -522,9 +515,7 @@ def save_results(
             logging.info("Wrote detailed CER stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -539,9 +530,7 @@ def save_results(
 
     if params.compute_CER:
         test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
-        errs_info = (
-            params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
        with open(errs_info, "w") as f:
            print("settings\tcER", file=f)
            for key, val in test_set_cers:
@@ -569,7 +558,7 @@ def main():
         "greedy_search",
         "modified_beam_search",
     )
 
     if params.long_audio_recog:
         params.res_dir = params.exp_dir / (params.decoding_method + "long_audio")
     else:
@@ -579,9 +568,6 @@ def main():
         params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:
         params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
 
-    if params.random_left_padding:
-        params.suffix += f"random-left-padding"
-
     if "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
@@ -591,7 +577,7 @@ def main():
 
     if "ngram" in params.decoding_method:
         params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
 
     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
 
@@ -706,7 +692,7 @@ def main():
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
     # we need cut ids to display recognition results.
     args.return_cuts = True
     libriheavy = LibriHeavyAsrDataModule(args)
@@ -717,26 +703,34 @@ def main():
     ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
     long_audio_cuts = libriheavy.long_audio_cuts()
 
-    test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts,)
-    test_other_dl = libriheavy.valid_dataloaders(test_other_cuts,)
+    test_clean_dl = libriheavy.valid_dataloaders(
+        test_clean_cuts,
+    )
+    test_other_dl = libriheavy.valid_dataloaders(
+        test_other_cuts,
+    )
     ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
     ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
-    long_audio_dl = libriheavy.valid_dataloaders(long_audio_cuts,)
+    long_audio_dl = libriheavy.valid_dataloaders(
+        long_audio_cuts,
+    )
 
     if params.use_ls_test_set:
         test_sets = ["ls-test-clean", "ls-test-other"]
         test_dl = [ls_test_clean_dl, ls_test_other_dl]
     else:
         test_sets = ["test-clean", "test-other"]
         test_dl = [test_clean_dl, test_other_dl]
 
     if params.long_audio_recog:
         test_sets = ["long-audio"]
         test_dl = [long_audio_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         if params.use_ls_test_set:
-            f = open("data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", 'r')
+            f = open(
+                "data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", "r"
+            )
             biasing_words = f.read().strip().split()
             f.close()
         else:
@@ -747,7 +741,7 @@ def main():
             model=model,
             sp=sp,
             word_table=word_table,
             decoding_graph=decoding_graph,
         )
 
         save_results(
@@ -755,40 +749,41 @@ def main():
             test_set_name=test_set,
             results_dict=results_dict,
         )
 
         if params.post_normalization:
             if "-post-normalization" not in params.suffix:
                 params.suffix += "-post-normalization"
 
             new_res = {}
             for k in results_dict:
                 new_ans = []
                 for item in results_dict[k]:
                     id, ref, hyp = item
                     if params.use_ls_test_set:
-                        hyp = " ".join(hyp).replace("-", " ").split() # handle the hypens
+                        hyp = (
+                            " ".join(hyp).replace("-", " ").split()
+                        )  # handle the hypens
                         hyp = upper_only_alpha(" ".join(hyp)).split()
                         hyp = [word_normalization(w.upper()) for w in hyp]
                         hyp = " ".join(hyp).split()
                         hyp = [w for w in hyp if w != ""]
                         ref = upper_only_alpha(" ".join(ref)).split()
                     else:
                         hyp = upper_only_alpha(" ".join(hyp)).split()
                         ref = upper_only_alpha(" ".join(ref)).split()
-                    new_ans.append((id,ref,hyp))
+                    new_ans.append((id, ref, hyp))
                 new_res[k] = new_ans
 
             save_results(
                 params=params,
                 test_set_name=test_set,
                 results_dict=new_res,
                 biasing_words=biasing_words,
             )
 
             if params.suffix.endswith("-post-normalization"):
                 params.suffix = params.suffix.replace("-post-normalization", "")
 
-
     logging.info("Done!")
 
 