icefall (mirror of https://github.com/k2-fsa/icefall.git)
commit 77890a6115 (parent: d4c5a1c157)

add context biasing at different levels
@@ -57,6 +57,7 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
+from ls_text_normalization import word_normalization
 from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha
 from train_baseline import (
     add_model_arguments,
@@ -76,8 +77,8 @@ from icefall.utils import (
     setup_logger,
     store_transcripts,
     str2bool,
-    write_error_stats,
 )
+from utils import write_error_stats

 LOG_EPS = math.log(1e-10)

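Note on the import swap above: the commit stops using icefall.utils.write_error_stats and instead pulls a project-local write_error_stats from utils, which later hunks call with an extra biasing_words argument. The local implementation is not part of this diff; below is a minimal stub of the interface the call sites assume (the body and docstring are assumptions, not the repo's code):

    from typing import List, Optional, TextIO

    def write_error_stats(
        f: TextIO,
        test_set_name: str,
        results: List[tuple],  # (cut_id, ref_words, hyp_words) triples
        enable_log: bool = True,
        biasing_words: Optional[List[str]] = None,
    ) -> float:
        # Assumed behavior: like icefall.utils.write_error_stats, but it
        # additionally reports error rates on words in `biasing_words`
        # and returns the overall WER.
        ...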
@@ -480,6 +481,7 @@ def save_results(
     params: AttributeDict,
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+    biasing_words: List[str] = None,
 ):
     test_set_wers = dict()
     test_set_cers = dict()
@@ -494,7 +496,7 @@ def save_results(
         errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=True
+                f, f"{test_set_name}-{key}", results, enable_log=True, biasing_words=biasing_words,
             )
             test_set_wers[key] = wer

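To make the biasing_words mechanism concrete, here is a self-contained sketch (not this repo's implementation) of how word errors can be split into biasing-list and ordinary words using a simple alignment; the local write_error_stats presumably does something similar on top of its usual WER bookkeeping:

    import difflib
    from typing import List, Set

    def biased_error_counts(refs: List[List[str]], hyps: List[List[str]], biasing: Set[str]):
        """Count errors on reference words, split by biasing-list membership."""
        errs_biased = errs_other = total_biased = total_other = 0
        for ref, hyp in zip(refs, hyps):
            matched = set()
            for op, i1, i2, j1, j2 in difflib.SequenceMatcher(a=ref, b=hyp).get_opcodes():
                if op == "equal":
                    matched.update(range(i1, i2))
            for i, w in enumerate(ref):
                if w in biasing:
                    total_biased += 1
                    errs_biased += i not in matched
                else:
                    total_other += 1
                    errs_other += i not in matched
        return errs_biased, total_biased, errs_other, total_other

    # "NAVARRE" is the rare (biasing) word; it is the only one misrecognized.
    print(biased_error_counts(
        [["THE", "KING", "OF", "NAVARRE"]],
        [["THE", "KING", "OF", "NAVAIR"]],
        {"NAVARRE"},
    ))  # -> (1, 1, 0, 3)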
@@ -740,6 +742,12 @@ def main():
         test_dl = [long_audio_dl]

     for test_set, test_dl in zip(test_sets, test_dl):
+        if params.use_ls_test_set:
+            f = open("data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", 'r')
+            biasing_words = f.read().strip().split()
+            f.close()
+        else:
+            biasing_words = None
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
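The open/read/close triple above works, but leaks the file handle if the read fails; a context-managed equivalent (same path as in the diff):

    with open("data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt") as f:
        biasing_words = f.read().strip().split()

biasing_words ends up as a plain list of rare words, which save_results then forwards to write_error_stats.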
@@ -781,6 +789,7 @@ def main():
             params=params,
             test_set_name=test_set,
             results_dict=new_res,
+            biasing_words=biasing_words,
         )

     if params.suffix.endswith("-post-normalization"):

@@ -61,7 +61,7 @@ from beam_search import (
     modified_beam_search,
 )
 from dataset import naive_triplet_text_sampling, random_shuffle_subset
-from utils import get_facebook_biasing_list
+from utils import get_facebook_biasing_list, brian_biasing_list, write_error_stats
 from ls_text_normalization import word_normalization
 from text_normalization import (
     ref_text_normalization,
@@ -92,7 +92,6 @@ from icefall.utils import (
     setup_logger,
     store_transcripts,
     str2bool,
-    write_error_stats,
 )

 LOG_EPS = math.log(1e-10)
@@ -334,11 +333,18 @@ def get_parser():
         help="If use a fixed context list for LibriSpeech decoding"
     )

+    parser.add_argument(
+        "--biasing-level",
+        type=str,
+        default="utterance",
+        choices=["utterance", "Book", "Chapter"],
+    )
+
     parser.add_argument(
         "--ls-distractors",
-        type=str2bool,
-        default=True,
-        help="If add distractors into context list for LibriSpeech decoding"
+        type=int,
+        default=0,
+        help="The number of distractors into context list for LibriSpeech decoding"
     )

     add_model_arguments(parser)
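Note the semantics change in this hunk: --ls-distractors was a str2bool on/off switch and is now an integer count, with the default 0 meaning no distractors. Since 0 is falsy, truthiness checks such as `if params.ls_distractors:` keep working. A standalone sketch of the new flag (the argument name and type come from the diff; the rest is illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ls-distractors",
        type=int,
        default=0,
        help="The number of distractors added to the context list",
    )
    args = parser.parse_args(["--ls-distractors", "100"])
    if args.ls_distractors:  # 0 disables distractors entirely
        print(f"adding {args.ls_distractors} distractors per utterance")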
@@ -430,13 +436,20 @@ def decode_one_batch(
            pre_texts = ["" for _ in range(batch_size)]

        if params.use_ls_context_list:
-           pre_texts = [biasing_dict[id] for id in cut_ids]
+           if params.biasing_level == "utterance":
+               pre_texts = [biasing_dict[id] for id in cut_ids]
+           elif params.biasing_level == "Chapter":
+               chapter_ids = [c.split('-')[1] for c in cut_ids]
+               pre_texts = [biasing_dict[id] for id in chapter_ids]
+           elif params.biasing_level == "Book":
+               chapter_ids = [c.split('-')[1] for c in cut_ids]
+               pre_texts = [biasing_dict[id] for id in chapter_ids]
            if params.pre_text_transform == "mixed-punc":
                pre_texts = [t.lower() for t in pre_texts]

        # get style_text
        if params.use_style_prompt:
-           fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
+           fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related."
            style_texts = batch["supervisions"].get("style_text", [fixed_sentence for _ in range(batch_size)])
            style_texts = [train_text_normalization(t) for t in style_texts]
        else:
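The Chapter and Book branches assume LibriSpeech-style cut ids of the form speaker-chapter-utterance, so index 1 after splitting on '-' is the chapter id; note that as written the Book branch derives the same chapter key, so the book-level biasing_dict is presumably keyed by chapter id as well. A quick illustration of the key extraction:

    cut_ids = ["1089-134686-0001", "1089-134686-0002", "121-127105-0000"]
    chapter_ids = [c.split("-")[1] for c in cut_ids]
    print(chapter_ids)  # ['134686', '134686', '127105']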
@@ -447,7 +460,8 @@ def decode_one_batch(

     # apply style transform to the pre_text and style_text
     pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
-    pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
+    if not params.use_ls_context_list:
+        pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
     #pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
     if params.use_style_prompt:
         style_texts = _apply_style_transform(style_texts, params.style_text_transform)
@@ -461,7 +475,9 @@ def decode_one_batch(
             style_texts=style_texts,
             tokenizer=tokenizer,
             device=device,
+            no_limit=True,
         )
+        logging.info(f"Shape of the encoded prompts: {encoded_inputs['input_ids'].shape}")

         memory, memory_key_padding_mask = model.encode_text(
             encoded_inputs=encoded_inputs,
@@ -666,6 +682,7 @@ def save_results(
     params: AttributeDict,
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+    biasing_words: List[str] = None,
 ):
     test_set_wers = dict()
     test_set_cers = dict()
@@ -792,9 +809,9 @@ def main():
         params.suffix += f"-use-context-fuser"

     if params.use_ls_context_list:
-        params.suffix += f"-use-ls-context-list"
-        if params.ls_distractors:
-            params.suffix += f"-add-ls-context-distractors"
+        params.suffix += f"-use-{params.biasing_level}-level-ls-context-list"
+        if params.biasing_level == "utterance" and params.ls_distractors:
+            params.suffix += f"-ls-context-distractors-{params.ls_distractors}"

     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
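The experiment suffix now records both the biasing level and, for utterance-level runs, the distractor count, so result files from different settings no longer collide. For example, with hypothetical values biasing_level="utterance" and ls_distractors=100:

    suffix = ""
    biasing_level, ls_distractors = "utterance", 100
    suffix += f"-use-{biasing_level}-level-ls-context-list"
    if biasing_level == "utterance" and ls_distractors:
        suffix += f"-ls-context-distractors-{ls_distractors}"
    print(suffix)  # -use-utterance-level-ls-context-list-ls-context-distractors-100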
@@ -919,6 +936,7 @@ def main():
        ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
        ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
        long_audio_cuts = libriheavy.long_audio_cuts()

+       npr1_dev_cuts = libriheavy.npr1_dev_cuts()
+       npr1_test_cuts = libriheavy.npr1_test_cuts()

@@ -60,7 +60,8 @@ from beam_search import (
     modified_beam_search,
 )
 from dataset import naive_triplet_text_sampling, random_shuffle_subset, get_substring
-from utils import get_facebook_biasing_list
+from utils import get_facebook_biasing_list, brian_biasing_list, write_error_stats
+from ls_text_normalization import word_normalization
 from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha, upper_all_char, train_text_normalization
 from train_subformer_with_style import (
     add_model_arguments,
@@ -82,7 +83,6 @@ from icefall.utils import (
     setup_logger,
     store_transcripts,
     str2bool,
-    write_error_stats,
 )

 LOG_EPS = math.log(1e-10)
@@ -324,11 +324,18 @@ def get_parser():
         help="If use a fixed context list for LibriSpeech decoding"
     )

+    parser.add_argument(
+        "--biasing-level",
+        type=str,
+        default="utterance",
+        choices=["utterance", "Book", "Chapter"],
+    )
+
     parser.add_argument(
         "--ls-distractors",
-        type=str2bool,
-        default=True,
-        help="If add distractors into context list for LibriSpeech decoding"
+        type=int,
+        default=0,
+        help="The number of distractors into context list for LibriSpeech decoding"
     )

     add_model_arguments(parser)
@@ -414,11 +421,19 @@ def decode_one_batch(

     if "pre_text" in batch["supervisions"] and params.use_pre_text:
         pre_texts = batch["supervisions"]["pre_text"]
+        pre_texts = [train_text_normalization(t) for t in pre_texts]
     else:
         pre_texts = ["" for _ in range(batch_size)]

     if params.use_ls_context_list:
-        pre_texts = [biasing_dict[id] for id in cut_ids]
+        if params.biasing_level == "utterance":
+            pre_texts = [biasing_dict[id] for id in cut_ids]
+        elif params.biasing_level == "Chapter":
+            chapter_ids = [c.split('-')[1] for c in cut_ids]
+            pre_texts = [biasing_dict[id] for id in chapter_ids]
+        elif params.biasing_level == "Book":
+            chapter_ids = [c.split('-')[1] for c in cut_ids]
+            pre_texts = [biasing_dict[id] for id in chapter_ids]
         if params.pre_text_transform == "mixed-punc":
             pre_texts = [t.lower() for t in pre_texts]

@@ -434,7 +449,8 @@ def decode_one_batch(

     # apply style transform to the pre_text and style_text
     pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
-    pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
+    if not params.use_ls_context_list:
+        pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
     #pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
     if params.use_style_prompt:
         style_texts = _apply_style_transform(style_texts, params.style_text_transform)
@@ -448,8 +464,9 @@ def decode_one_batch(
             style_texts=style_texts,
             bpe_model=text_encoder_bpe_model,
             device=device,
-            max_tokens=1000,
+            max_tokens=8000,
         )
+        logging.info(f"Shape of the encoded prompts: {pre_texts.shape}")

         memory, memory_key_padding_mask = model.encode_text(
             text=pre_texts,
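Raising max_tokens from 1000 to 8000 matters because chapter- and book-level context lists are much longer than per-utterance ones; with the old cap most of such a prompt would be cut off before encoding. A rough sketch of BPE-encode-then-truncate, assuming a sentencepiece processor (the repo's actual prompt-encoding helper is not shown in this diff):

    import sentencepiece as spm

    def encode_prompts(texts, sp: spm.SentencePieceProcessor, max_tokens: int = 8000):
        # Encode each prompt to BPE ids, keeping at most max_tokens per prompt.
        return [sp.encode(t)[:max_tokens] for t in texts]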
@@ -608,6 +625,13 @@ def decode_dataset(
         texts = _apply_style_transform(texts, params.style_text_transform)

         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+        if not params.use_ls_test_set:
+            try:
+                book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]]
+            except:
+                book_names = [cut.id.split('/')[0] for cut in batch["supervisions"]["cut"]]
+        else:
+            book_names = ["" for _ in cut_ids]

         hyps_dict = decode_one_batch(
             params=params,
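The try/except covers two cut layouts: cuts with a text_path attribute like .../<book_dir>/<file>.txt, where the parent directory names the book, and cuts whose id begins with '<book>/...'. The bare except: is broader than needed; AttributeError is the realistic failure here, so a slightly safer equivalent would be:

    def get_book_name(cut) -> str:
        try:
            return cut.text_path.split("/")[-2]
        except AttributeError:  # cut has no text_path; fall back to the id layout
            return cut.id.split("/")[0]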
@@ -623,13 +647,14 @@ def decode_dataset(
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
-            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+            for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts):
                 ref_text = ref_text_normalization(
                     ref_text
                 )  # remove full-width symbols & some book marks
                 ref_words = ref_text.split()
                 this_batch.append((cut_id, ref_words, hyp_words))

+            # if not params.use_ls_test_set:
+            # results[name + "_" + book_name].extend(this_batch)
             results[name].extend(this_batch)

         num_cuts += len(texts)
@@ -647,6 +672,7 @@ def save_results(
     params: AttributeDict,
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+    biasing_words: List[str] = None,
 ):
     test_set_wers = dict()
     test_set_cers = dict()
@@ -661,7 +687,7 @@ def save_results(
         errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
-                f, f"{test_set_name}-{key}", results, enable_log=True
+                f, f"{test_set_name}-{key}", results, enable_log=True, biasing_words=biasing_words,
             )
             test_set_wers[key] = wer

@@ -769,9 +795,9 @@ def main():
         params.suffix += f"-use-context-fuser"

     if params.use_ls_context_list:
-        params.suffix += f"-use-ls-context-list"
-        if params.ls_distractors:
-            params.suffix += f"-add-ls-context-distractors"
+        params.suffix += f"-use-{params.biasing_level}-level-ls-context-list"
+        if params.biasing_level == "utterance" and params.ls_distractors:
+            params.suffix += f"-ls-context-distractors-{params.ls_distractors}"

     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
@@ -921,16 +947,28 @@ def main():
         test_sets = ["long-audio"]
         test_dl = [long_audio_dl]

+    #test_sets = ["npr1-dev", "npr1-test"]
+    #test_dl = [npr1_dev_dl, npr1_test_dl]
+    if params.long_audio_recog:
+        test_sets = ["long-audio"]
+        test_dl = [long_audio_dl]
+
     for test_set, test_dl in zip(test_sets, test_dl):
-        if test_set == "ls-test-clean":
-            biasing_dict = get_facebook_biasing_list("test-clean", use_distractors=params.ls_distractors)
-        elif test_set == "ls-test-other":
-            biasing_dict = get_facebook_biasing_list("test-other", use_distractors=params.ls_distractors)
+        if params.biasing_level == "utterance":
+            if test_set == "ls-test-clean":
+                biasing_dict = get_facebook_biasing_list("test-clean", use_distractors=params.ls_distractors, num_distractors=params.ls_distractors)
+            elif test_set == "ls-test-other":
+                biasing_dict = get_facebook_biasing_list("test-other", use_distractors=params.ls_distractors, num_distractors=params.ls_distractors)
+            else:
+                biasing_dict = None
+            f = open("data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", 'r')
+            biasing_words = f.read().strip().split()
+            f.close()
+
         else:
-            biasing_dict = None
+            if params.use_ls_test_set:
+                biasing_dict = brian_biasing_list(params.biasing_level)
+            else:
+                biasing_dict = None
+            biasing_words = None

         results_dict = decode_dataset(
             dl=test_dl,
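For non-utterance levels the biasing context comes from brian_biasing_list(params.biasing_level); given how decode_one_batch indexes it, it is presumably a dict from chapter id to the biasing text for that chapter or book. A purely hypothetical shape, for orientation only:

    # Hypothetical: chapter id -> space-joined biasing words (not the real data)
    biasing_dict = {
        "134686": "NAVARRE QUIXOTE ROSINANTE",
        "127105": "PIP ESTELLA MAGWITCH",
    }
    pre_texts = [biasing_dict[c.split("-")[1]] for c in ["1089-134686-0001"]]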
@@ -961,7 +999,7 @@ def main():
            if params.use_ls_test_set:
                hyp = " ".join(hyp).replace("-", " ").split()  # handle the hypens
                hyp = upper_only_alpha(" ".join(hyp)).split()
-               hyp = [word_normalization(w.upper()) for w in hyp]
+               hyp = [word_normalization(str(w).upper()) for w in hyp]
                hyp = " ".join(hyp).split()
                hyp = [w for w in hyp if w != ""]
                ref = upper_only_alpha(" ".join(ref)).split()
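The str(w) wrapper guards word_normalization against non-string tokens slipping through upstream processing, which would otherwise crash on .upper(). A one-line illustration:

    hyp = ["hello", 1989]
    print([str(w).upper() for w in hyp])  # ['HELLO', '1989']; plain w.upper() would raise AttributeError on 1989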
@@ -975,6 +1013,7 @@ def main():
             params=params,
             test_set_name=test_set,
             results_dict=new_res,
+            biasing_words=biasing_words,
         )

     if params.suffix.endswith("-post-normalization"):