fix the post-normalization bug: avoid tokens containing multiple words

marcoyang1998 2023-08-16 09:39:42 +08:00
parent fdc4fcabb9
commit 17d0918969
2 changed files with 33 additions and 14 deletions

View File

@@ -57,7 +57,7 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
-from text_normalization import ref_text_normalization, remove_non_alphabetic
+from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha
 from train_baseline import (
     add_model_arguments,
     get_params,
@@ -742,10 +742,8 @@ def main():
             new_ans = []
             for item in results_dict[k]:
                 id, ref, hyp = item
-                hyp = [remove_non_alphabetic(w.upper(), strict=False) for w in hyp]
-                hyp = [w for w in hyp if w != ""]
-                ref = [remove_non_alphabetic(w.upper(), strict=False) for w in ref]
-                ref = [w for w in ref if w != ""]
+                hyp = upper_only_alpha(" ".join(hyp)).split()
+                ref = upper_only_alpha(" ".join(ref)).split()
                 new_ans.append((id, ref, hyp))
             new_res[k] = new_ans
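
Why the per-word version was buggy: remove_non_alphabetic(..., strict=False) maps non-alphabetic characters to spaces, so a token such as "well-known" becomes the two-word string "WELL KNOWN" yet still counts as a single list element, which corrupts word-level scoring. Normalizing the joined sentence and re-splitting guarantees one word per element. A minimal standalone sketch of the difference; this upper_only_alpha is a stand-in with the behavior assumed from the diff, not the repo's exact implementation:

    import re

    def upper_only_alpha(s: str) -> str:
        # Assumed behavior: uppercase, then turn anything that is not
        # a letter or whitespace into a space.
        return re.sub(r"[^A-Z\s]", " ", s.upper())

    hyp = ["it's", "well-known", "fine"]

    # Old approach: normalize word by word; multi-word strings survive.
    old = [w for w in (upper_only_alpha(w) for w in hyp) if w != ""]
    # -> ["IT S", "WELL KNOWN", "FINE"]

    # New approach: normalize the joined text, then split, so every
    # element is guaranteed to be a single word.
    new = upper_only_alpha(" ".join(hyp)).split()
    # -> ["IT", "S", "WELL", "KNOWN", "FINE"]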
@@ -754,6 +752,8 @@ def main():
             test_set_name=test_set,
             results_dict=new_res,
         )
+
+        params.suffix = params.suffix.replace("-post-normalization", "")

     logging.info("Done!")
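
The new suffix line undoes a presumed earlier params.suffix += "-post-normalization" once the post-normalized results are written, so the tag does not leak into file names produced afterwards. The pattern in isolation (the file-name format is assumed for illustration):

    suffix = "epoch-30-avg-9"

    # Tag the results produced by the post-normalization block...
    suffix += "-post-normalization"
    print(f"errs-{suffix}.txt")  # errs-epoch-30-avg-9-post-normalization.txt

    # ...then restore the suffix so later outputs are untagged.
    suffix = suffix.replace("-post-normalization", "")
    print(f"errs-{suffix}.txt")  # errs-epoch-30-avg-9.txt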

View File

@@ -62,7 +62,15 @@ from beam_search import (
 )
 from dataset import naive_triplet_text_sampling, random_shuffle_subset
 from utils import get_facebook_biasing_list
-from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha, upper_all_char, lower_all_char, lower_only_alpha
+from text_normalization import (
+    ref_text_normalization,
+    remove_non_alphabetic,
+    upper_only_alpha,
+    upper_all_char,
+    lower_all_char,
+    lower_only_alpha,
+    train_text_normalization,
+)
 from train_bert_encoder_with_style import (
     add_model_arguments,
     get_params,
@@ -313,6 +321,13 @@ def get_parser():
         help="Whether to use a fixed context list for LibriSpeech decoding",
     )
+    parser.add_argument(
+        "--ls-distractors",
+        type=str2bool,
+        default=True,
+        help="Whether to add distractors into the context list for LibriSpeech decoding",
+    )
+
     add_model_arguments(parser)
     return parser
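
The flag is declared with type=str2bool (presumably the helper from icefall.utils, as in other icefall recipes) so booleans can be passed as strings on the command line, e.g. --ls-distractors False. A minimal equivalent of that helper for readers outside icefall:

    import argparse

    def str2bool(v):
        # Accept real booleans and the usual string spellings.
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser()
    parser.add_argument("--ls-distractors", type=str2bool, default=True)
    print(parser.parse_args(["--ls-distractors", "false"]).ls_distractors)  # False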
@@ -393,20 +408,24 @@ def decode_one_batch(
     cuts = batch["supervisions"]["cut"]
     cut_ids = [c.supervisions[0].id for c in cuts]
     batch_size = feature.size(0)

     # get pre_text
     if "pre_text" in batch["supervisions"] and params.use_pre_text:
         pre_texts = batch["supervisions"]["pre_text"]
+        pre_texts = [train_text_normalization(t) for t in pre_texts]
     else:
         pre_texts = ["" for _ in range(batch_size)]

     if params.use_ls_context_list:
         pre_texts = [biasing_dict[id] for id in cut_ids]
+        if params.pre_text_transform == "mixed-punc":
+            pre_texts = [t.lower() for t in pre_texts]

     # get style_text
     if params.use_style_prompt:
         fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
         style_texts = batch["supervisions"].get("style_text", [fixed_sentence for _ in range(batch_size)])
+        style_texts = [train_text_normalization(t) for t in style_texts]
     else:
         style_texts = ["" for _ in range(batch_size)]  # use empty string
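
Net effect of the pre_text branch above: the per-cut biasing list, when --use-ls-context-list is given, overrides whatever pre_text the batch carries, and is lowercased when the model was trained with the mixed-punc pre-text transform. The same precedence as a self-contained sketch (build_pre_texts and the normalize stand-in are hypothetical names, not this recipe's API):

    from typing import Callable, Dict, List, Optional

    def build_pre_texts(
        batch_pre_texts: Optional[List[str]],
        cut_ids: List[str],
        biasing_dict: Optional[Dict[str, str]],
        pre_text_transform: str,
        normalize: Callable[[str], str] = lambda s: s,
    ) -> List[str]:
        # Default: the batch's own (normalized) pre_text, else empty strings.
        if batch_pre_texts is not None:
            pre_texts = [normalize(t) for t in batch_pre_texts]
        else:
            pre_texts = ["" for _ in cut_ids]
        # The biasing list, when present, takes priority.
        if biasing_dict is not None:
            pre_texts = [biasing_dict[cid] for cid in cut_ids]
            if pre_text_transform == "mixed-punc":
                pre_texts = [t.lower() for t in pre_texts]
        return pre_texts

    print(build_pre_texts(None, ["cut-1"], {"cut-1": "BOIL QUINCE"}, "mixed-punc"))
    # -> ['boil quince']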
@@ -748,6 +767,8 @@ def main():
     if params.use_ls_context_list:
         params.suffix += f"-use-ls-context-list"
+        if params.ls_distractors:
+            params.suffix += "-add-ls-context-distractors"

     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
@@ -886,9 +907,9 @@
     for test_set, test_dl in zip(test_sets, test_dl):
         if test_set == "ls-test-clean":
-            biasing_dict = get_facebook_biasing_list("test-clean")
+            biasing_dict = get_facebook_biasing_list("test-clean", use_distractors=params.ls_distractors)
         elif test_set == "ls-test-other":
-            biasing_dict = get_facebook_biasing_list("test-other")
+            biasing_dict = get_facebook_biasing_list("test-other", use_distractors=params.ls_distractors)
         else:
             biasing_dict = None
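
From its use here and in decode_one_batch (pre_texts = [biasing_dict[id] for id in cut_ids]), get_facebook_biasing_list evidently returns a dict keyed by cut id with that cut's biasing words as one string. The toy stand-in below only illustrates that assumed shape and the intended effect of use_distractors, i.e. padding each entry with rare words from other utterances so the model must pick out the relevant ones; the cut ids and words are made up:

    from typing import Dict

    def toy_biasing_list(test_set: str, use_distractors: bool = False) -> Dict[str, str]:
        # Assumed shape: cut id -> space-separated biasing words.
        ground_truth = {
            "1089-134686-0000": "BOIL STIRRED",
            "1188-133604-0001": "QUINCE MARGATE",
        }
        if not use_distractors:
            return ground_truth
        # With distractors, every entry also carries words that do NOT
        # occur in that utterance.
        all_words = " ".join(ground_truth.values())
        return {cid: all_words for cid in ground_truth}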
@@ -918,12 +939,10 @@ def main():
             new_ans = []
             for item in results_dict[k]:
                 id, ref, hyp = item
-                hyp = [remove_non_alphabetic(w.upper(), strict=False) for w in hyp]
-                hyp = [w for w in hyp if w != ""]
-                ref = [remove_non_alphabetic(w.upper(), strict=False) for w in ref]
-                ref = [w for w in ref if w != ""]
+                hyp = upper_only_alpha(" ".join(hyp)).split()
+                ref = upper_only_alpha(" ".join(ref)).split()
                 new_ans.append((id, ref, hyp))
-                new_res[k] = new_ans
+            new_res[k] = new_ans

         save_results(
             params=params,