From 17d0918969b94d56c4576381818b1351381dd58f Mon Sep 17 00:00:00 2001
From: marcoyang1998
Date: Wed, 16 Aug 2023 09:39:42 +0800
Subject: [PATCH] fix the post normalization bug, avoid multiple words

---
 .../zipformer_prompt_asr/decode_baseline.py   | 10 ++---
 .../decode_bert_with_style.py                 | 37 ++++++++++++++-----
 2 files changed, 33 insertions(+), 14 deletions(-)

diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py
index 17e97afbc..cf59e7503 100644
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_baseline.py
@@ -57,7 +57,7 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
-from text_normalization import ref_text_normalization, remove_non_alphabetic
+from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha
 from train_baseline import (
     add_model_arguments,
     get_params,
@@ -742,10 +742,8 @@ def main():
             new_ans = []
             for item in results_dict[k]:
                 id, ref, hyp = item
-                hyp = [remove_non_alphabetic(w.upper(), strict=False) for w in hyp]
-                hyp = [w for w in hyp if w != ""]
-                ref = [remove_non_alphabetic(w.upper(), strict=False) for w in ref]
-                ref = [w for w in ref if w != ""]
+                hyp = upper_only_alpha(" ".join(hyp)).split()
+                ref = upper_only_alpha(" ".join(ref)).split()
                 new_ans.append((id, ref, hyp))
             new_res[k] = new_ans
 
@@ -754,6 +752,8 @@ def main():
             test_set_name=test_set,
             results_dict=new_res,
         )
+
+        params.suffix = params.suffix.replace("-post-normalization", "")
 
     logging.info("Done!")
 
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style.py b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style.py
index 8608c1a2d..90203a954 100755
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/decode_bert_with_style.py
@@ -62,7 +62,15 @@ from beam_search import (
 )
 from dataset import naive_triplet_text_sampling, random_shuffle_subset
 from utils import get_facebook_biasing_list
-from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha, upper_all_char, lower_all_char, lower_only_alpha
+from text_normalization import (
+    ref_text_normalization,
+    remove_non_alphabetic,
+    upper_only_alpha,
+    upper_all_char,
+    lower_all_char,
+    lower_only_alpha,
+    train_text_normalization,
+)
 from train_bert_encoder_with_style import (
     add_model_arguments,
     get_params,
@@ -313,6 +321,13 @@ def get_parser():
         help="If use a fixed context list for LibriSpeech decoding"
     )
 
+    parser.add_argument(
+        "--ls-distractors",
+        type=str2bool,
+        default=True,
+        help="If add distractors into context list for LibriSpeech decoding"
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -393,20 +408,24 @@ def decode_one_batch(
     cuts = batch["supervisions"]["cut"]
     cut_ids = [c.supervisions[0].id for c in cuts]
     batch_size = feature.size(0)
-    
+
     # get pre_text
     if "pre_text" in batch["supervisions"] and params.use_pre_text:
         pre_texts = batch["supervisions"]["pre_text"]
+        pre_texts = [train_text_normalization(t) for t in pre_texts]
     else:
         pre_texts = ["" for _ in range(batch_size)]
 
     if params.use_ls_context_list:
         pre_texts = [biasing_dict[id] for id in cut_ids]
+        if params.pre_text_transform == "mixed-punc":
+            pre_texts = [t.lower() for t in pre_texts]
 
     # get style_text
     if params.use_style_prompt:
         fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
         style_texts = batch["supervisions"].get("style_text", [fixed_sentence for _ in range(batch_size)])
+        style_texts = [train_text_normalization(t) for t in style_texts]
     else:
         style_texts = ["" for _ in range(batch_size)]  # use empty string
 
@@ -748,6 +767,8 @@ def main():
 
     if params.use_ls_context_list:
         params.suffix += f"-use-ls-context-list"
+        if params.ls_distractors:
+            params.suffix += f"-add-ls-context-distractors"
 
     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
@@ -886,9 +907,9 @@ def main():
 
     for test_set, test_dl in zip(test_sets, test_dl):
         if test_set == "ls-test-clean":
-            biasing_dict = get_facebook_biasing_list("test-clean")
+            biasing_dict = get_facebook_biasing_list("test-clean", use_distractors=params.ls_distractors)
         elif test_set == "ls-test-other":
-            biasing_dict = get_facebook_biasing_list("test-other")
+            biasing_dict = get_facebook_biasing_list("test-other", use_distractors=params.ls_distractors)
         else:
             biasing_dict = None
 
@@ -918,12 +939,10 @@ def main():
             new_ans = []
             for item in results_dict[k]:
                 id, ref, hyp = item
-                hyp = [remove_non_alphabetic(w.upper(), strict=False) for w in hyp]
-                hyp = [w for w in hyp if w != ""]
-                ref = [remove_non_alphabetic(w.upper(), strict=False) for w in ref]
-                ref = [w for w in ref if w != ""]
+                hyp = upper_only_alpha(" ".join(hyp)).split()
+                ref = upper_only_alpha(" ".join(ref)).split()
                 new_ans.append((id,ref,hyp))
-            new_res[k] = new_ans
+            new_res[k] = new_ans
 
     save_results(
         params=params,
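
Reviewer note on the hyp/ref change that appears in both files: the old loop applied remove_non_alphabetic() to each token separately, so a token such as "twenty-seven" could come back as "TWENTY SEVEN", a single list element holding two words, which skews WER scoring downstream. Joining the tokens, normalizing the whole utterance once, and splitting guarantees one word per element. The minimal sketch below demonstrates this; the bodies of remove_non_alphabetic() and upper_only_alpha() are assumptions for illustration only, not the actual implementations in text_normalization.py.

import re


def remove_non_alphabetic(text: str, strict: bool = False) -> str:
    # Assumed behavior: in non-strict mode, hyphens and other punctuation
    # are replaced by a space instead of being deleted outright.
    if strict:
        return re.sub(r"[^a-zA-Z\s]+", "", text)
    return re.sub(r"[^a-zA-Z\s]+", " ", text)


def upper_only_alpha(text: str) -> str:
    # Assumed behavior: uppercase the text, then keep only letters and spaces.
    return remove_non_alphabetic(text.upper(), strict=False)


hyp = ["twenty-seven", "dollars"]

# Old post-normalization: one token can silently become multiple words.
old = [remove_non_alphabetic(w.upper(), strict=False) for w in hyp]
old = [w for w in old if w != ""]
print(old)  # ['TWENTY SEVEN', 'DOLLARS'] -- 2 list elements, but 3 words

# New post-normalization: join, normalize the full string, then split.
new = upper_only_alpha(" ".join(hyp)).split()
print(new)  # ['TWENTY', 'SEVEN', 'DOLLARS']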
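
Reviewer note on the new --ls-distractors switch: the patch only shows the changed call sites, so the stub below is a hypothetical sketch of the assumed contract of utils.get_facebook_biasing_list() (cut id mapped to a space-separated biasing string, optionally padded with distractor words when use_distractors is True). The toy cut ids and words are made up for illustration.

from typing import Dict

# Hypothetical toy data standing in for the LibriSpeech biasing lists that
# the real utils.get_facebook_biasing_list() loads.
_TOY_LISTS = {
    "test-clean": {
        "cut-0001": (["STOCKBROKER"], ["BAKERY", "LANTERN"]),
    },
}


def get_facebook_biasing_list(test_set: str, use_distractors: bool = True) -> Dict[str, str]:
    # Assumed contract, mirroring the call sites changed in this patch:
    # each cut id maps to its biasing words, followed by distractor words
    # only when use_distractors is True.
    out = {}
    for cut_id, (targets, distractors) in _TOY_LISTS[test_set].items():
        words = targets + (distractors if use_distractors else [])
        out[cut_id] = " ".join(words)
    return out


print(get_facebook_biasing_list("test-clean", use_distractors=False))
# {'cut-0001': 'STOCKBROKER'}
print(get_facebook_biasing_list("test-clean", use_distractors=True))
# {'cut-0001': 'STOCKBROKER BAKERY LANTERN'}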