mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
Fix the post-normalization bug: normalize the joined text and re-split it, to avoid multiple words in one entry
This commit is contained in:
parent
fdc4fcabb9
commit
17d0918969
@ -57,7 +57,7 @@ from beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from text_normalization import ref_text_normalization, remove_non_alphabetic
|
||||
from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha
|
||||
from train_baseline import (
|
||||
add_model_arguments,
|
||||
get_params,
|
||||
@ -742,10 +742,8 @@ def main():
|
||||
new_ans = []
|
||||
for item in results_dict[k]:
|
||||
id, ref, hyp = item
|
||||
hyp = [remove_non_alphabetic(w.upper(), strict=False) for w in hyp]
|
||||
hyp = [w for w in hyp if w != ""]
|
||||
ref = [remove_non_alphabetic(w.upper(), strict=False) for w in ref]
|
||||
ref = [w for w in ref if w != ""]
|
||||
hyp = upper_only_alpha(" ".join(hyp)).split()
|
||||
ref = upper_only_alpha(" ".join(ref)).split()
|
||||
new_ans.append((id, ref, hyp))
|
||||
new_res[k] = new_ans
|
||||
|
||||
@ -754,6 +752,8 @@ def main():
|
||||
test_set_name=test_set,
|
||||
results_dict=new_res,
|
||||
)
|
||||
|
||||
params.suffix = params.suffix.replace("-post-normalization", "")
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
@ -62,7 +62,15 @@ from beam_search import (
|
||||
)
|
||||
from dataset import naive_triplet_text_sampling, random_shuffle_subset
|
||||
from utils import get_facebook_biasing_list
|
||||
from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha, upper_all_char, lower_all_char, lower_only_alpha
|
||||
from text_normalization import (
|
||||
ref_text_normalization,
|
||||
remove_non_alphabetic,
|
||||
upper_only_alpha,
|
||||
upper_all_char,
|
||||
lower_all_char,
|
||||
lower_only_alpha,
|
||||
train_text_normalization,
|
||||
)
|
||||
from train_bert_encoder_with_style import (
|
||||
add_model_arguments,
|
||||
get_params,
|
||||
@ -313,6 +321,13 @@ def get_parser():
|
||||
help="If use a fixed context list for LibriSpeech decoding"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ls-distractors",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="If add distractors into context list for LibriSpeech decoding"
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -393,20 +408,24 @@ def decode_one_batch(
|
||||
cuts = batch["supervisions"]["cut"]
|
||||
cut_ids = [c.supervisions[0].id for c in cuts]
|
||||
batch_size = feature.size(0)
|
||||
|
||||
|
||||
# get pre_text
|
||||
if "pre_text" in batch["supervisions"] and params.use_pre_text:
|
||||
pre_texts = batch["supervisions"]["pre_text"]
|
||||
pre_texts = [train_text_normalization(t) for t in pre_texts]
|
||||
else:
|
||||
pre_texts = ["" for _ in range(batch_size)]
|
||||
|
||||
if params.use_ls_context_list:
|
||||
pre_texts = [biasing_dict[id] for id in cut_ids]
|
||||
if params.pre_text_transform == "mixed-punc":
|
||||
pre_texts = [t.lower() for t in pre_texts]
|
||||
|
||||
# get style_text
|
||||
if params.use_style_prompt:
|
||||
fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
|
||||
style_texts = batch["supervisions"].get("style_text", [fixed_sentence for _ in range(batch_size)])
|
||||
style_texts = [train_text_normalization(t) for t in style_texts]
|
||||
else:
|
||||
style_texts = ["" for _ in range(batch_size)] # use empty string
|
||||
|
||||
@ -748,6 +767,8 @@ def main():
|
||||
|
||||
if params.use_ls_context_list:
|
||||
params.suffix += f"-use-ls-context-list"
|
||||
if params.ls_distractors:
|
||||
params.suffix += f"-add-ls-context-distractors"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
@ -886,9 +907,9 @@ def main():
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
if test_set == "ls-test-clean":
|
||||
biasing_dict = get_facebook_biasing_list("test-clean")
|
||||
biasing_dict = get_facebook_biasing_list("test-clean", use_distractors=params.ls_distractors)
|
||||
elif test_set == "ls-test-other":
|
||||
biasing_dict = get_facebook_biasing_list("test-other")
|
||||
biasing_dict = get_facebook_biasing_list("test-other", use_distractors=params.ls_distractors)
|
||||
else:
|
||||
biasing_dict = None
|
||||
|
||||
@ -918,12 +939,10 @@ def main():
|
||||
new_ans = []
|
||||
for item in results_dict[k]:
|
||||
id, ref, hyp = item
|
||||
hyp = [remove_non_alphabetic(w.upper(), strict=False) for w in hyp]
|
||||
hyp = [w for w in hyp if w != ""]
|
||||
ref = [remove_non_alphabetic(w.upper(), strict=False) for w in ref]
|
||||
ref = [w for w in ref if w != ""]
|
||||
hyp = upper_only_alpha(" ".join(hyp)).split()
|
||||
ref = upper_only_alpha(" ".join(ref)).split()
|
||||
new_ans.append((id,ref,hyp))
|
||||
new_res[k] = new_ans
|
||||
new_res[k] = new_ans
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
|
Loading…
x
Reference in New Issue
Block a user