icefall (https://github.com/k2-fsa/icefall.git)

commit 77890a6115 (parent d4c5a1c157)

    add context biasing at different levels
@@ -57,6 +57,7 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
+from ls_text_normalization import word_normalization
 from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha
 from train_baseline import (
     add_model_arguments,
@@ -76,8 +77,8 @@ from icefall.utils import (
     setup_logger,
     store_transcripts,
     str2bool,
-    write_error_stats,
 )
+from utils import write_error_stats

 LOG_EPS = math.log(1e-10)

@@ -480,6 +481,7 @@ def save_results(
     params: AttributeDict,
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+    biasing_words: List[str] = None,
 ):
     test_set_wers = dict()
     test_set_cers = dict()
@@ -494,7 +496,7 @@ def save_results(
         errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=True
+                f, f"{test_set_name}-{key}", results, enable_log=True, biasing_words=biasing_words,
             )
             test_set_wers[key] = wer

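Note: the `write_error_stats` used here is now the project-local one (imported from `utils` in the hunk above) rather than `icefall.utils.write_error_stats`, and its body is not part of this commit. A minimal sketch of the kind of biased-word breakdown such a function can report, assuming a simple bag-of-words tally rather than the edit-distance alignment the real helper presumably uses:

from typing import Dict, List, Optional, Tuple

# Hypothetical sketch: split reference-word recall into biasing vs. non-biasing
# buckets. `results` follows the (cut_id, ref_words, hyp_words) layout used above.
def biased_word_recall(
    results: List[Tuple[str, List[str], List[str]]],
    biasing_words: Optional[List[str]] = None,
) -> Dict[str, float]:
    biasing = set(biasing_words or [])
    hit = {"biased": 0, "unbiased": 0}
    total = {"biased": 0, "unbiased": 0}
    for _cut_id, ref_words, hyp_words in results:
        hyp_counts: Dict[str, int] = {}
        for w in hyp_words:
            hyp_counts[w] = hyp_counts.get(w, 0) + 1
        for w in ref_words:
            key = "biased" if w in biasing else "unbiased"
            total[key] += 1
            if hyp_counts.get(w, 0) > 0:  # count each hyp token at most once
                hyp_counts[w] -= 1
                hit[key] += 1
    return {k: hit[k] / max(total[k], 1) for k in total}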
@@ -740,6 +742,12 @@ def main():
         test_dl = [long_audio_dl]

     for test_set, test_dl in zip(test_sets, test_dl):
+        if params.use_ls_test_set:
+            f = open("data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", 'r')
+            biasing_words = f.read().strip().split()
+            f.close()
+        else:
+            biasing_words = None
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
@@ -781,6 +789,7 @@ def main():
             params=params,
             test_set_name=test_set,
             results_dict=new_res,
+            biasing_words=biasing_words,
         )

     if params.suffix.endswith("-post-normalization"):
@@ -61,7 +61,7 @@ from beam_search import (
     modified_beam_search,
 )
 from dataset import naive_triplet_text_sampling, random_shuffle_subset
-from utils import get_facebook_biasing_list
+from utils import get_facebook_biasing_list, brian_biasing_list, write_error_stats
 from ls_text_normalization import word_normalization
 from text_normalization import (
     ref_text_normalization,
@@ -92,7 +92,6 @@ from icefall.utils import (
     setup_logger,
     store_transcripts,
     str2bool,
-    write_error_stats,
 )

 LOG_EPS = math.log(1e-10)
@@ -334,11 +333,18 @@ def get_parser():
         help="If use a fixed context list for LibriSpeech decoding"
     )

+    parser.add_argument(
+        "--biasing-level",
+        type=str,
+        default="utterance",
+        choices=["utterance", "Book", "Chapter"],
+    )
+
     parser.add_argument(
         "--ls-distractors",
-        type=str2bool,
-        default=True,
-        help="If add distractors into context list for LibriSpeech decoding"
+        type=int,
+        default=0,
+        help="The number of distractors into context list for LibriSpeech decoding"
     )

     add_model_arguments(parser)
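Note: the new `--biasing-level` flag decides which key indexes the biasing dictionary in the decoding hunks below. LibriSpeech-style cut ids have the form speaker-chapter-utterance, so `c.split('-')[1]` extracts the chapter id. A small illustration (the id below is made up):

# Made-up LibriSpeech-style cut id, for illustration only.
cut_id = "1089-134686-0001"
speaker_id, chapter_id, utt_id = cut_id.split("-")

biasing_level = "Chapter"
# utterance-level biasing keys the dict by the full cut id;
# Chapter- and Book-level biasing key it by the chapter id.
key = cut_id if biasing_level == "utterance" else chapter_id
assert key == "134686"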
@@ -430,13 +436,20 @@ def decode_one_batch(
             pre_texts = ["" for _ in range(batch_size)]

         if params.use_ls_context_list:
-            pre_texts = [biasing_dict[id] for id in cut_ids]
+            if params.biasing_level == "utterance":
+                pre_texts = [biasing_dict[id] for id in cut_ids]
+            elif params.biasing_level == "Chapter":
+                chapter_ids = [c.split('-')[1] for c in cut_ids]
+                pre_texts = [biasing_dict[id] for id in chapter_ids]
+            elif params.biasing_level == "Book":
+                chapter_ids = [c.split('-')[1] for c in cut_ids]
+                pre_texts = [biasing_dict[id] for id in chapter_ids]
         if params.pre_text_transform == "mixed-punc":
             pre_texts = [t.lower() for t in pre_texts]

         # get style_text
         if params.use_style_prompt:
-            fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
+            fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related."
             style_texts = batch["supervisions"].get("style_text", [fixed_sentence for _ in range(batch_size)])
             style_texts = [train_text_normalization(t) for t in style_texts]
         else:
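Note: the `Book` branch keys the lookup by chapter id exactly as the `Chapter` branch does; presumably `brian_biasing_list` returns a chapter-keyed dict at both levels, with a book's list repeated for each of its chapters (an inference, not shown in this diff). The branching reduces to a small helper, sketched here:

from typing import Dict, List

def lookup_pre_texts(
    cut_ids: List[str],
    biasing_dict: Dict[str, str],
    biasing_level: str,
) -> List[str]:
    # Mirrors the branching above: utterance level uses the full cut id,
    # Chapter and Book levels both use the chapter field of the id.
    if biasing_level == "utterance":
        return [biasing_dict[i] for i in cut_ids]
    chapter_ids = [c.split("-")[1] for c in cut_ids]
    return [biasing_dict[i] for i in chapter_ids]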
@@ -447,7 +460,8 @@ def decode_one_batch(

         # apply style transform to the pre_text and style_text
         pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
-        pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
+        if not params.use_ls_context_list:
+            pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
         #pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
         if params.use_style_prompt:
             style_texts = _apply_style_transform(style_texts, params.style_text_transform)
@@ -461,7 +475,9 @@ def decode_one_batch(
             style_texts=style_texts,
             tokenizer=tokenizer,
             device=device,
+            no_limit=True,
         )
+        logging.info(f"Shape of the encoded prompts: {encoded_inputs['input_ids'].shape}")

         memory, memory_key_padding_mask = model.encode_text(
             encoded_inputs=encoded_inputs,
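Note: `no_limit=True` is a new flag on the prompt-encoding helper (its definition is not in this diff); the likely intent is to lift the prompt-length cap so full context lists are encoded, which would also be why the `max_prompt_lens` truncation above is now skipped for context lists. The logged shape is the usual (batch, seq_len) of an HF-style tokenizer output, e.g.:

# Assumes an HF-style tokenizer, as suggested by encoded_inputs['input_ids'];
# the model name here is only an example.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoded_inputs = tokenizer(["RARE WORDS HERE", "MORE WORDS"], padding=True, return_tensors="pt")
print(encoded_inputs["input_ids"].shape)  # e.g. torch.Size([2, 5])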
@@ -666,6 +682,7 @@ def save_results(
     params: AttributeDict,
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+    biasing_words: List[str] = None,
 ):
     test_set_wers = dict()
     test_set_cers = dict()
@@ -792,9 +809,9 @@ def main():
         params.suffix += f"-use-context-fuser"

     if params.use_ls_context_list:
-        params.suffix += f"-use-ls-context-list"
-        if params.ls_distractors:
-            params.suffix += f"-add-ls-context-distractors"
+        params.suffix += f"-use-{params.biasing_level}-level-ls-context-list"
+        if params.biasing_level == "utterance" and params.ls_distractors:
+            params.suffix += f"-ls-context-distractors-{params.ls_distractors}"

     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
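Note: with these changes the experiment suffix records both the biasing level and the distractor count. A worked example of the pieces added above:

biasing_level = "utterance"
ls_distractors = 100

suffix = f"-use-{biasing_level}-level-ls-context-list"
if biasing_level == "utterance" and ls_distractors:
    suffix += f"-ls-context-distractors-{ls_distractors}"
assert suffix == "-use-utterance-level-ls-context-list-ls-context-distractors-100"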
@@ -919,6 +936,7 @@ def main():
         ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
         ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
         long_audio_cuts = libriheavy.long_audio_cuts()
+
         npr1_dev_cuts = libriheavy.npr1_dev_cuts()
         npr1_test_cuts = libriheavy.npr1_test_cuts()

@@ -60,7 +60,8 @@ from beam_search import (
     modified_beam_search,
 )
 from dataset import naive_triplet_text_sampling, random_shuffle_subset, get_substring
-from utils import get_facebook_biasing_list
+from utils import get_facebook_biasing_list, brian_biasing_list, write_error_stats
+from ls_text_normalization import word_normalization
 from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha, upper_all_char, train_text_normalization
 from train_subformer_with_style import (
     add_model_arguments,
@@ -82,7 +83,6 @@ from icefall.utils import (
     setup_logger,
     store_transcripts,
     str2bool,
-    write_error_stats,
 )

 LOG_EPS = math.log(1e-10)
@@ -324,11 +324,18 @@ def get_parser():
         help="If use a fixed context list for LibriSpeech decoding"
     )

+    parser.add_argument(
+        "--biasing-level",
+        type=str,
+        default="utterance",
+        choices=["utterance", "Book", "Chapter"],
+    )
+
     parser.add_argument(
         "--ls-distractors",
-        type=str2bool,
-        default=True,
-        help="If add distractors into context list for LibriSpeech decoding"
+        type=int,
+        default=0,
+        help="The number of distractors into context list for LibriSpeech decoding"
     )

     add_model_arguments(parser)
@@ -414,11 +421,19 @@ def decode_one_batch(

         if "pre_text" in batch["supervisions"] and params.use_pre_text:
             pre_texts = batch["supervisions"]["pre_text"]
+            pre_texts = [train_text_normalization(t) for t in pre_texts]
         else:
             pre_texts = ["" for _ in range(batch_size)]

         if params.use_ls_context_list:
-            pre_texts = [biasing_dict[id] for id in cut_ids]
+            if params.biasing_level == "utterance":
+                pre_texts = [biasing_dict[id] for id in cut_ids]
+            elif params.biasing_level == "Chapter":
+                chapter_ids = [c.split('-')[1] for c in cut_ids]
+                pre_texts = [biasing_dict[id] for id in chapter_ids]
+            elif params.biasing_level == "Book":
+                chapter_ids = [c.split('-')[1] for c in cut_ids]
+                pre_texts = [biasing_dict[id] for id in chapter_ids]
         if params.pre_text_transform == "mixed-punc":
             pre_texts = [t.lower() for t in pre_texts]

@@ -434,7 +449,8 @@ def decode_one_batch(

         # apply style transform to the pre_text and style_text
         pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
-        pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
+        if not params.use_ls_context_list:
+            pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
         #pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
         if params.use_style_prompt:
             style_texts = _apply_style_transform(style_texts, params.style_text_transform)
@@ -448,8 +464,9 @@ def decode_one_batch(
             style_texts=style_texts,
             bpe_model=text_encoder_bpe_model,
             device=device,
-            max_tokens=1000,
+            max_tokens=8000,
         )
+        logging.info(f"Shape of the encoded prompts: {pre_texts.shape}")

         memory, memory_key_padding_mask = model.encode_text(
             text=pre_texts,
@@ -608,6 +625,13 @@ def decode_dataset(
             texts = _apply_style_transform(texts, params.style_text_transform)

         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+        if not params.use_ls_test_set:
+            try:
+                book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]]
+            except:
+                book_names = [cut.id.split('/')[0] for cut in batch["supervisions"]["cut"]]
+        else:
+            book_names = ["" for _ in cut_ids]

         hyps_dict = decode_one_batch(
             params=params,
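Note: the `book_names` extraction relies on Libriheavy cut layout, falling back from `cut.text_path` (whose second-to-last path component is the book directory) to the cut id (whose first `/`-separated field is the book). A stripped-down check, with made-up values:

# Made-up Libriheavy-style values, for illustration only.
text_path = "download/librilight_text/small/sea_fairies/text.txt"
cut_id = "sea_fairies/recording_01/42"

book_from_path = text_path.split('/')[-2]  # -> "sea_fairies"
book_from_id = cut_id.split('/')[0]        # -> "sea_fairies"
assert book_from_path == book_from_id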
@@ -623,13 +647,14 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+            for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts):
                 ref_text = ref_text_normalization(
                     ref_text
                 )  # remove full-width symbols & some book marks
                 ref_words = ref_text.split()
                 this_batch.append((cut_id, ref_words, hyp_words))
+            # if not params.use_ls_test_set:
+            #     results[name + "_" + book_name].extend(this_batch)
             results[name].extend(this_batch)

             num_cuts += len(texts)
@@ -647,6 +672,7 @@ def save_results(
     params: AttributeDict,
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+    biasing_words: List[str] = None,
 ):
     test_set_wers = dict()
     test_set_cers = dict()
@@ -661,7 +687,7 @@ def save_results(
         errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=True
+                f, f"{test_set_name}-{key}", results, enable_log=True, biasing_words=biasing_words,
             )
             test_set_wers[key] = wer

@@ -769,9 +795,9 @@ def main():
         params.suffix += f"-use-context-fuser"

     if params.use_ls_context_list:
-        params.suffix += f"-use-ls-context-list"
-        if params.ls_distractors:
-            params.suffix += f"-add-ls-context-distractors"
+        params.suffix += f"-use-{params.biasing_level}-level-ls-context-list"
+        if params.biasing_level == "utterance" and params.ls_distractors:
+            params.suffix += f"-ls-context-distractors-{params.ls_distractors}"

     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
@@ -921,16 +947,28 @@ def main():
         test_sets = ["long-audio"]
         test_dl = [long_audio_dl]

-    #test_sets = ["npr1-dev", "npr1-test"]
-    #test_dl = [npr1_dev_dl, npr1_test_dl]
+    if params.long_audio_recog:
+        test_sets = ["long-audio"]
+        test_dl = [long_audio_dl]

     for test_set, test_dl in zip(test_sets, test_dl):
-        if test_set == "ls-test-clean":
-            biasing_dict = get_facebook_biasing_list("test-clean", use_distractors=params.ls_distractors)
-        elif test_set == "ls-test-other":
-            biasing_dict = get_facebook_biasing_list("test-other", use_distractors=params.ls_distractors)
+        if params.biasing_level == "utterance":
+            if test_set == "ls-test-clean":
+                biasing_dict = get_facebook_biasing_list("test-clean", use_distractors=params.ls_distractors, num_distractors=params.ls_distractors)
+            elif test_set == "ls-test-other":
+                biasing_dict = get_facebook_biasing_list("test-other", use_distractors=params.ls_distractors, num_distractors=params.ls_distractors)
+            else:
+                biasing_dict = None
+            f = open("data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", 'r')
+            biasing_words = f.read().strip().split()
+            f.close()
+
         else:
-            biasing_dict = None
+            if params.use_ls_test_set:
+                biasing_dict = brian_biasing_list(params.biasing_level)
+            else:
+                biasing_dict = None
+            biasing_words = None

         results_dict = decode_dataset(
             dl=test_dl,
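Note: `get_facebook_biasing_list` and `brian_biasing_list` come from the project-local `utils.py` and are not shown in this commit. From the call sites, plausible signatures are roughly as below (inferred, not confirmed):

from typing import Dict

# Inferred from the call sites above; the real definitions live in utils.py.
def get_facebook_biasing_list(
    test_set: str,          # "test-clean" or "test-other"
    use_distractors: bool,  # truthy when --ls-distractors > 0
    num_distractors: int,   # the integer --ls-distractors value
) -> Dict[str, str]:
    """Maps each LibriSpeech cut id to its biasing-word string."""
    ...

def brian_biasing_list(level: str) -> Dict[str, str]:
    """Maps chapter ids to Chapter- or Book-level biasing strings."""
    ...

Passing `params.ls_distractors` (now an int) to both `use_distractors` and `num_distractors` works because a positive count is truthy.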
@@ -961,7 +999,7 @@ def main():
             if params.use_ls_test_set:
                 hyp = " ".join(hyp).replace("-", " ").split()  # handle the hypens
                 hyp = upper_only_alpha(" ".join(hyp)).split()
-                hyp = [word_normalization(w.upper()) for w in hyp]
+                hyp = [word_normalization(str(w).upper()) for w in hyp]
                 hyp = " ".join(hyp).split()
                 hyp = [w for w in hyp if w != ""]
                 ref = upper_only_alpha(" ".join(ref)).split()
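Note: wrapping `w` in `str(...)` guards `word_normalization` against non-string tokens. A self-contained rerun of the clean-up chain, with stand-ins for the helpers (the real `upper_only_alpha` and `word_normalization` live in `text_normalization` and `ls_text_normalization`; the stubs here only approximate them):

import re

def upper_only_alpha(s: str) -> str:
    # Approximation: keep only alphabetic characters, apostrophes and spaces, uppercased.
    return re.sub(r"[^A-Z' ]", "", s.upper())

def word_normalization(w: str) -> str:
    return w  # stub; the real helper normalizes spelled-out numbers and the like

hyp = ["state-of-the-art", "asr,", "works"]
hyp = " ".join(hyp).replace("-", " ").split()      # split hyphenated words
hyp = upper_only_alpha(" ".join(hyp)).split()      # strip punctuation, uppercase
hyp = [word_normalization(str(w).upper()) for w in hyp]
hyp = " ".join(hyp).split()                        # re-split after normalization
hyp = [w for w in hyp if w != ""]                  # drop empty tokens
assert hyp == ["STATE", "OF", "THE", "ART", "ASR", "WORKS"]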
@@ -975,6 +1013,7 @@ def main():
             params=params,
             test_set_name=test_set,
             results_dict=new_res,
+            biasing_words=biasing_words,
         )

     if params.suffix.endswith("-post-normalization"):