Add context biasing at different levels (utterance, chapter, book)

This commit is contained in:
marcoyang1998 2023-09-08 09:56:45 +08:00
parent d4c5a1c157
commit 77890a6115
3 changed files with 101 additions and 35 deletions

View File

@ -57,6 +57,7 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
)
from ls_text_normalization import word_normalization
from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha
from train_baseline import (
add_model_arguments,
@ -76,8 +77,8 @@ from icefall.utils import (
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
from utils import write_error_stats
LOG_EPS = math.log(1e-10)
@ -480,6 +481,7 @@ def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
biasing_words: List[str] = None,
):
test_set_wers = dict()
test_set_cers = dict()
@ -494,7 +496,7 @@ def save_results(
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
f, f"{test_set_name}-{key}", results, enable_log=True, biasing_words=biasing_words,
)
test_set_wers[key] = wer
@ -740,6 +742,12 @@ def main():
test_dl = [long_audio_dl]
for test_set, test_dl in zip(test_sets, test_dl):
if params.use_ls_test_set:
f = open("data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", 'r')
biasing_words = f.read().strip().split()
f.close()
else:
biasing_words = None
results_dict = decode_dataset(
dl=test_dl,
params=params,
@ -781,6 +789,7 @@ def main():
params=params,
test_set_name=test_set,
results_dict=new_res,
biasing_words=biasing_words,
)
if params.suffix.endswith("-post-normalization"):

View File

@ -61,7 +61,7 @@ from beam_search import (
modified_beam_search,
)
from dataset import naive_triplet_text_sampling, random_shuffle_subset
from utils import get_facebook_biasing_list
from utils import get_facebook_biasing_list, brian_biasing_list, write_error_stats
from ls_text_normalization import word_normalization
from text_normalization import (
ref_text_normalization,
@ -92,7 +92,6 @@ from icefall.utils import (
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
@ -334,11 +333,18 @@ def get_parser():
help="If use a fixed context list for LibriSpeech decoding"
)
parser.add_argument(
"--biasing-level",
type=str,
default="utterance",
choices=["utterance", "Book", "Chapter"],
)
parser.add_argument(
"--ls-distractors",
type=str2bool,
default=True,
help="If add distractors into context list for LibriSpeech decoding"
type=int,
default=0,
help="The number of distractors into context list for LibriSpeech decoding"
)
add_model_arguments(parser)
@ -430,13 +436,20 @@ def decode_one_batch(
pre_texts = ["" for _ in range(batch_size)]
if params.use_ls_context_list:
if params.biasing_level == "utterance":
pre_texts = [biasing_dict[id] for id in cut_ids]
elif params.biasing_level == "Chapter":
chapter_ids = [c.split('-')[1] for c in cut_ids]
pre_texts = [biasing_dict[id] for id in chapter_ids]
elif params.biasing_level == "Book":
chapter_ids = [c.split('-')[1] for c in cut_ids]
pre_texts = [biasing_dict[id] for id in chapter_ids]
if params.pre_text_transform == "mixed-punc":
pre_texts = [t.lower() for t in pre_texts]
# get style_text
if params.use_style_prompt:
fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related."
style_texts = batch["supervisions"].get("style_text", [fixed_sentence for _ in range(batch_size)])
style_texts = [train_text_normalization(t) for t in style_texts]
else:
@ -447,6 +460,7 @@ def decode_one_batch(
# apply style transform to the pre_text and style_text
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
if not params.use_ls_context_list:
pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
#pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
if params.use_style_prompt:
@ -461,7 +475,9 @@ def decode_one_batch(
style_texts=style_texts,
tokenizer=tokenizer,
device=device,
no_limit=True,
)
logging.info(f"Shape of the encoded prompts: {encoded_inputs['input_ids'].shape}")
memory, memory_key_padding_mask = model.encode_text(
encoded_inputs=encoded_inputs,
@ -666,6 +682,7 @@ def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
biasing_words: List[str] = None,
):
test_set_wers = dict()
test_set_cers = dict()
@ -792,9 +809,9 @@ def main():
params.suffix += f"-use-context-fuser"
if params.use_ls_context_list:
params.suffix += f"-use-ls-context-list"
if params.ls_distractors:
params.suffix += f"-add-ls-context-distractors"
params.suffix += f"-use-{params.biasing_level}-level-ls-context-list"
if params.biasing_level == "utterance" and params.ls_distractors:
params.suffix += f"-ls-context-distractors-{params.ls_distractors}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
@ -919,6 +936,7 @@ def main():
ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
long_audio_cuts = libriheavy.long_audio_cuts()
npr1_dev_cuts = libriheavy.npr1_dev_cuts()
npr1_test_cuts = libriheavy.npr1_test_cuts()

View File

@ -60,7 +60,8 @@ from beam_search import (
modified_beam_search,
)
from dataset import naive_triplet_text_sampling, random_shuffle_subset, get_substring
from utils import get_facebook_biasing_list
from utils import get_facebook_biasing_list, brian_biasing_list, write_error_stats
from ls_text_normalization import word_normalization
from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha, upper_all_char, train_text_normalization
from train_subformer_with_style import (
add_model_arguments,
@ -82,7 +83,6 @@ from icefall.utils import (
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
@ -324,11 +324,18 @@ def get_parser():
help="If use a fixed context list for LibriSpeech decoding"
)
parser.add_argument(
"--biasing-level",
type=str,
default="utterance",
choices=["utterance", "Book", "Chapter"],
)
parser.add_argument(
"--ls-distractors",
type=str2bool,
default=True,
help="If add distractors into context list for LibriSpeech decoding"
type=int,
default=0,
help="The number of distractors into context list for LibriSpeech decoding"
)
add_model_arguments(parser)
@ -414,11 +421,19 @@ def decode_one_batch(
if "pre_text" in batch["supervisions"] and params.use_pre_text:
pre_texts = batch["supervisions"]["pre_text"]
pre_texts = [train_text_normalization(t) for t in pre_texts]
else:
pre_texts = ["" for _ in range(batch_size)]
if params.use_ls_context_list:
if params.biasing_level == "utterance":
pre_texts = [biasing_dict[id] for id in cut_ids]
elif params.biasing_level == "Chapter":
chapter_ids = [c.split('-')[1] for c in cut_ids]
pre_texts = [biasing_dict[id] for id in chapter_ids]
elif params.biasing_level == "Book":
chapter_ids = [c.split('-')[1] for c in cut_ids]
pre_texts = [biasing_dict[id] for id in chapter_ids]
if params.pre_text_transform == "mixed-punc":
pre_texts = [t.lower() for t in pre_texts]
@ -434,6 +449,7 @@ def decode_one_batch(
# apply style transform to the pre_text and style_text
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
if not params.use_ls_context_list:
pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
#pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
if params.use_style_prompt:
@ -448,8 +464,9 @@ def decode_one_batch(
style_texts=style_texts,
bpe_model=text_encoder_bpe_model,
device=device,
max_tokens=1000,
max_tokens=8000,
)
logging.info(f"Shape of the encoded prompts: {pre_texts.shape}")
memory, memory_key_padding_mask = model.encode_text(
text=pre_texts,
@ -608,6 +625,13 @@ def decode_dataset(
texts = _apply_style_transform(texts, params.style_text_transform)
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
if not params.use_ls_test_set:
try:
book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]]
except:
book_names = [cut.id.split('/')[0] for cut in batch["supervisions"]["cut"]]
else:
book_names = ["" for _ in cut_ids]
hyps_dict = decode_one_batch(
params=params,
@ -623,13 +647,14 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts):
ref_text = ref_text_normalization(
ref_text
) # remove full-width symbols & some book marks
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
# if not params.use_ls_test_set:
# results[name + "_" + book_name].extend(this_batch)
results[name].extend(this_batch)
num_cuts += len(texts)
@ -647,6 +672,7 @@ def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
biasing_words: List[str] = None,
):
test_set_wers = dict()
test_set_cers = dict()
@ -661,7 +687,7 @@ def save_results(
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
f, f"{test_set_name}-{key}", results, enable_log=True, biasing_words=biasing_words,
)
test_set_wers[key] = wer
@ -769,9 +795,9 @@ def main():
params.suffix += f"-use-context-fuser"
if params.use_ls_context_list:
params.suffix += f"-use-ls-context-list"
if params.ls_distractors:
params.suffix += f"-add-ls-context-distractors"
params.suffix += f"-use-{params.biasing_level}-level-ls-context-list"
if params.biasing_level == "utterance" and params.ls_distractors:
params.suffix += f"-ls-context-distractors-{params.ls_distractors}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
@ -921,16 +947,28 @@ def main():
test_sets = ["long-audio"]
test_dl = [long_audio_dl]
#test_sets = ["npr1-dev", "npr1-test"]
#test_dl = [npr1_dev_dl, npr1_test_dl]
if params.long_audio_recog:
test_sets = ["long-audio"]
test_dl = [long_audio_dl]
for test_set, test_dl in zip(test_sets, test_dl):
if params.biasing_level == "utterance":
if test_set == "ls-test-clean":
biasing_dict = get_facebook_biasing_list("test-clean", use_distractors=params.ls_distractors)
biasing_dict = get_facebook_biasing_list("test-clean", use_distractors=params.ls_distractors, num_distractors=params.ls_distractors)
elif test_set == "ls-test-other":
biasing_dict = get_facebook_biasing_list("test-other", use_distractors=params.ls_distractors)
biasing_dict = get_facebook_biasing_list("test-other", use_distractors=params.ls_distractors, num_distractors=params.ls_distractors)
else:
biasing_dict = None
f = open("data/context_biasing/LibriSpeechBiasingLists/all_rare_words.txt", 'r')
biasing_words = f.read().strip().split()
f.close()
else:
if params.use_ls_test_set:
biasing_dict = brian_biasing_list(params.biasing_level)
else:
biasing_dict = None
biasing_words = None
results_dict = decode_dataset(
dl=test_dl,
@ -961,7 +999,7 @@ def main():
if params.use_ls_test_set:
hyp = " ".join(hyp).replace("-", " ").split() # handle the hypens
hyp = upper_only_alpha(" ".join(hyp)).split()
hyp = [word_normalization(w.upper()) for w in hyp]
hyp = [word_normalization(str(w).upper()) for w in hyp]
hyp = " ".join(hyp).split()
hyp = [w for w in hyp if w != ""]
ref = upper_only_alpha(" ".join(ref)).split()
@ -975,6 +1013,7 @@ def main():
params=params,
test_set_name=test_set,
results_dict=new_res,
biasing_words=biasing_words,
)
if params.suffix.endswith("-post-normalization"):