fix style

marcoyang 2023-10-10 16:48:17 +08:00
parent 7a9c18fc79
commit 9ef5dfe380


@@ -52,18 +52,15 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriHeavyAsrDataModule
-from beam_search import (
-    greedy_search,
-    greedy_search_batch,
-    modified_beam_search,
-)
+from beam_search import greedy_search, greedy_search_batch, modified_beam_search
 from ls_text_normalization import word_normalization
-from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha
-from train_baseline import (
-    add_model_arguments,
-    get_params,
-    get_transducer_model,
-)
+from text_normalization import (
+    ref_text_normalization,
+    remove_non_alphabetic,
+    upper_only_alpha,
+)
+from train_baseline import add_model_arguments, get_params, get_transducer_model
+from utils import write_error_stats
 from icefall.checkpoint import (
     average_checkpoints,
@@ -72,13 +69,7 @@ from icefall.checkpoint import (
     load_checkpoint,
 )
 from icefall.lexicon import Lexicon
-from icefall.utils import (
-    AttributeDict,
-    setup_logger,
-    store_transcripts,
-    str2bool,
-)
-from utils import write_error_stats
+from icefall.utils import AttributeDict, setup_logger, store_transcripts, str2bool
 
 LOG_EPS = math.log(1e-10)
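The two import hunks above are what an automatic formatter such as black produces with its default 88-character line limit: an import list that fits on one line is collapsed, and one that does not is exploded to one name per line with a trailing comma. A minimal sketch of that rule, assuming the black package is installed (black.format_str is part of its public API):

    import black

    # 80 characters: fits within the default 88-column limit, stays on one line.
    short = (
        "from beam_search import greedy_search, greedy_search_batch,"
        " modified_beam_search\n"
    )
    print(black.format_str(short, mode=black.Mode()), end="")

    # 94 characters: over the limit, so black explodes it with a trailing comma.
    long = (
        "from text_normalization import ref_text_normalization,"
        " remove_non_alphabetic, upper_only_alpha\n"
    )
    print(black.format_str(long, mode=black.Mode()), end="")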
@@ -242,34 +233,34 @@ def get_parser():
         Used only when the decoding method is fast_beam_search_nbest,
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )
 
     parser.add_argument(
         "--post-normalization",
         type=str2bool,
         default=True,
         help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ",
     )
 
     parser.add_argument(
         "--long-audio-recog",
         type=str2bool,
         default=False,
     )
 
     parser.add_argument(
         "--use-ls-test-set",
         type=str2bool,
         default=False,
-        help="Use librispeech test set for evaluation."
+        help="Use librispeech test set for evaluation.",
     )
 
     parser.add_argument(
         "--compute-CER",
         type=str2bool,
         default=True,
         help="Reports CER. By default, only reports WER",
     )
 
     add_model_arguments(parser)
 
     return parser
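All four arguments in this hunk use icefall's str2bool converter rather than type=bool, because argparse calls the type function on the raw string and bool("false") is truthy. A minimal re-implementation for illustration (the real helper is imported from icefall.utils above and may differ in detail):

    import argparse


    def str2bool(v: str) -> bool:
        # bool("false") is True, so inspect the string instead.
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")


    parser = argparse.ArgumentParser()
    parser.add_argument("--compute-CER", type=str2bool, default=True)
    print(parser.parse_args(["--compute-CER", "false"]).compute_CER)  # False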
@@ -333,14 +324,14 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         encoder_out, encoder_out_lens = model.encode_audio(
             feature=feature,
             feature_lens=feature_lens,
         )
 
     hyps = []
 
     if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
@@ -373,12 +364,6 @@ def decode_one_batch(
                 encoder_out=encoder_out_i,
                 max_sym_per_frame=params.max_sym_per_frame,
             )
-        elif params.decoding_method == "beam_search":
-            hyp = beam_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                beam=params.beam_size,
-            )
         else:
             raise ValueError(
                 f"Unsupported decoding method: {params.decoding_method}"
@@ -442,7 +427,9 @@ def decode_dataset(
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         if not params.use_ls_test_set:
-            book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]]
+            book_names = [
+                cut.text_path.split("/")[-2] for cut in batch["supervisions"]["cut"]
+            ]
         else:
             book_names = ["" for _ in cut_ids]
@@ -458,10 +445,10 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts):
-                ref_text = ref_text_normalization(
-                    ref_text
-                )
+            for cut_id, book_name, hyp_words, ref_text in zip(
+                cut_ids, book_names, hyps, texts
+            ):
+                ref_text = ref_text_normalization(ref_text)
                 ref_words = ref_text.split()
                 this_batch.append((cut_id, ref_words, hyp_words))
                 # if not params.use_ls_test_set:
@@ -496,7 +483,11 @@ def save_results(
         errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=True, biasing_words=biasing_words,
+                f,
+                f"{test_set_name}-{key}",
+                results,
+                enable_log=True,
+                biasing_words=biasing_words,
             )
             test_set_wers[key] = wer
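The write_error_stats here is the recipe-local variant (imported from utils, not icefall.utils) that also accepts biasing_words, so errors on the rare-word list can be reported separately. A sketch of that statistic, assuming the alignment is a list of (ref_word, hyp_word) pairs, which is not necessarily the function's real internal representation:

    def rare_word_errors(alignment, biasing_words):
        # alignment: (ref_word, hyp_word) pairs; "" marks an insertion/deletion.
        total = sum(1 for ref, _ in alignment if ref in biasing_words)
        errs = sum(1 for ref, hyp in alignment if ref in biasing_words and ref != hyp)
        return errs, total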
@@ -504,7 +495,9 @@ def save_results(
         if params.compute_CER:
             # Write CER statistics
-            recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
+            recog_path = (
+                params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
+            )
             store_transcripts(filename=recog_path, texts=results, char_level=True)
             errs_filename = (
                 params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
@@ -522,9 +515,7 @@ def save_results(
             logging.info("Wrote detailed CER stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
    with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
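The summary file is tab-separated with settings sorted best-first by WER. A toy illustration of the sorting and layout (keys and numbers are made up):

    test_set_wers = {"greedy_search": 6.12, "modified_beam_search_beam_4": 5.87}
    for key, val in sorted(test_set_wers.items(), key=lambda x: x[1]):
        print(f"{key}\t{val}")
    # modified_beam_search_beam_4   5.87
    # greedy_search                 6.12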
@@ -539,9 +530,7 @@ def save_results(
     if params.compute_CER:
         test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
-        errs_info = (
-            params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
         with open(errs_info, "w") as f:
             print("settings\tcER", file=f)
             for key, val in test_set_cers:
@@ -569,7 +558,7 @@ def main():
         "greedy_search",
         "modified_beam_search",
     )
 
     if params.long_audio_recog:
         params.res_dir = params.exp_dir / (params.decoding_method + "long_audio")
     else:
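Note that params.decoding_method + "long_audio" concatenates without a separator, so greedy search writes its results to a directory named greedy_searchlong_audio; the style pass leaves that behavior unchanged.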
@@ -579,9 +568,6 @@ def main():
         params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:
         params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    if params.random_left_padding:
-        params.suffix += f"random-left-padding"
-
     if "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
@@ -591,7 +577,7 @@ def main():
     if "ngram" in params.decoding_method:
         params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
 
     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
@@ -706,7 +692,7 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
     # we need cut ids to display recognition results.
     args.return_cuts = True
     libriheavy = LibriHeavyAsrDataModule(args)
@@ -717,26 +703,34 @@ def main():
     ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
     long_audio_cuts = libriheavy.long_audio_cuts()
 
-    test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts,)
-    test_other_dl = libriheavy.valid_dataloaders(test_other_cuts,)
+    test_clean_dl = libriheavy.valid_dataloaders(
+        test_clean_cuts,
+    )
+    test_other_dl = libriheavy.valid_dataloaders(
+        test_other_cuts,
+    )
     ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
     ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
-    long_audio_dl = libriheavy.valid_dataloaders(long_audio_cuts,)
+    long_audio_dl = libriheavy.valid_dataloaders(
+        long_audio_cuts,
+    )
 
     if params.use_ls_test_set:
         test_sets = ["ls-test-clean", "ls-test-other"]
         test_dl = [ls_test_clean_dl, ls_test_other_dl]
     else:
         test_sets = ["test-clean", "test-other"]
         test_dl = [test_clean_dl, test_other_dl]
 
     if params.long_audio_recog:
         test_sets = ["long-audio"]
         test_dl = [long_audio_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         if params.use_ls_test_set:
-            f = open("data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", 'r')
+            f = open(
+                "data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", "r"
+            )
             biasing_words = f.read().strip().split()
             f.close()
         else:
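The bare open()/f.close() pair survives the style pass, since a formatter only rewraps code. An equivalent context-manager form, shown only as the more idiomatic alternative (not what the commit does):

    with open(
        "data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt"
    ) as f:
        biasing_words = f.read().strip().split()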
@@ -747,7 +741,7 @@ def main():
             model=model,
             sp=sp,
             word_table=word_table,
             decoding_graph=decoding_graph,
         )
 
         save_results(
@@ -755,40 +749,41 @@ def main():
             test_set_name=test_set,
             results_dict=results_dict,
         )
 
         if params.post_normalization:
             if "-post-normalization" not in params.suffix:
                 params.suffix += "-post-normalization"
 
             new_res = {}
             for k in results_dict:
                 new_ans = []
                 for item in results_dict[k]:
                     id, ref, hyp = item
                     if params.use_ls_test_set:
-                        hyp = " ".join(hyp).replace("-", " ").split() # handle the hypens
+                        hyp = (
+                            " ".join(hyp).replace("-", " ").split()
+                        )  # handle the hypens
                         hyp = upper_only_alpha(" ".join(hyp)).split()
                         hyp = [word_normalization(w.upper()) for w in hyp]
                         hyp = " ".join(hyp).split()
                         hyp = [w for w in hyp if w != ""]
                         ref = upper_only_alpha(" ".join(ref)).split()
                     else:
                         hyp = upper_only_alpha(" ".join(hyp)).split()
                         ref = upper_only_alpha(" ".join(ref)).split()
-                    new_ans.append((id,ref,hyp))
+                    new_ans.append((id, ref, hyp))
                 new_res[k] = new_ans
 
             save_results(
                 params=params,
                 test_set_name=test_set,
                 results_dict=new_res,
                 biasing_words=biasing_words,
             )
 
             if params.suffix.endswith("-post-normalization"):
                 params.suffix = params.suffix.replace("-post-normalization", "")
 
     logging.info("Done!")