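"""Utilities for context biasing with rare-word lists on Libriheavy.

The functions below build a rare-word list for a Libriheavy subset, attach a
per-utterance context list of rare words to the manifest, report simple
statistics over it, load published biasing lists, and compute (biased) word
error rates.

A minimal command-line sketch (the script name is a placeholder; the paths
follow the defaults of ``get_parser`` below):

    python ./utils.py --manifest-dir data/fbank --subset medium --top-k 10000
"""
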
import argparse
import ast
import glob
import logging
import os
from collections import defaultdict
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union

import kaldialign
from lhotse import load_manifest, load_manifest_lazy
from lhotse.cut import Cut, CutSet
from text_normalization import remove_non_alphabetic
from tqdm import tqdm


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--manifest-dir",
        type=str,
        default="data/fbank",
        help="The directory where the manifests are stored",
    )

    parser.add_argument(
        "--subset",
        type=str,
        default="medium",
        help="Which Libriheavy subset to work with",
    )

    parser.add_argument(
        "--top-k",
        type=int,
        default=10000,
        help="The number of most frequent words to treat as common; "
        "words outside the top-k are considered rare",
    )

    return parser


def get_facebook_biasing_list(
    test_set: str,
    num_distractors: int = 100,
) -> Dict[str, str]:
    # Get the biasing list from the Meta (Facebook AI) paper:
    # https://arxiv.org/pdf/2104.02194.pdf
    assert num_distractors in (0, 100, 500, 1000, 2000), num_distractors
    if num_distractors == 0:
        if test_set == "test-clean":
            biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_100.tsv"
        elif test_set == "test-other":
            biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_100.tsv"
        else:
            raise ValueError(f"Unseen test set {test_set}")
    else:
        if test_set == "test-clean":
            biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv"
        elif test_set == "test-other":
            biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv"
        else:
            raise ValueError(f"Unseen test set {test_set}")

    with open(biasing_file, "r") as f:
        data = f.readlines()

    output = dict()
    for line in data:
        uid, _, no_distractor_list, distractor_list = line.split("\t")
        if num_distractors > 0:  # use the list that contains distractors
            biasing_list = ast.literal_eval(distractor_list)
        else:
            biasing_list = ast.literal_eval(no_distractor_list)
        biasing_list = [w.strip().upper() for w in biasing_list]
        output[uid] = " ".join(biasing_list)

    return output
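

# A minimal usage sketch for get_facebook_biasing_list (assumes the TSV files
# referenced above have already been downloaded under data/context_biasing):
#
#     biasing_dict = get_facebook_biasing_list("test-clean", num_distractors=100)
#     # -> maps each utterance id to a space-separated string of upper-cased
#     #    biasing words

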
def brian_biasing_list(level: str):
    # The biasing list from Brian's paper: https://arxiv.org/pdf/2109.00627.pdf
    root_dir = f"data/context_biasing/LibriSpeechBiasingLists/{level}Level"
    all_files = glob.glob(root_dir + "/*")
    biasing_dict = {}
    for f in all_files:
        k = f.split("/")[-1]
        with open(f, "r") as fin:
            data = fin.read().strip().split()
        biasing_dict[k] = " ".join(data)

    return biasing_dict
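

# A minimal usage sketch for brian_biasing_list (assumes the
# LibriSpeechBiasingLists directory has been downloaded; the "utterance" level
# below is an assumption and must match a sub-directory named "<level>Level"):
#
#     biasing_dict = brian_biasing_list(level="utterance")
#     # -> maps each file name under {level}Level/ to a space-separated word list

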
def get_rare_words(
    subset: str = "medium",
    top_k: int = 10000,
):
    """Get a list of rare words from the transcripts of a subset.

    A word is considered rare if it is not among the `top_k` most frequent
    words of the subset's transcripts.

    Args:
        subset: Which subset to use
        top_k: The number of most frequent words to treat as common
    """
    txt_path = f"data/tmp/transcript_words_{subset}.txt"
    rare_word_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt"

    if os.path.exists(rare_word_file):
        print(f"{rare_word_file} already exists, skipping!")
        return

    print("---Identifying rare words in the transcripts---")
    count_file = f"data/tmp/transcript_words_{subset}_count.txt"
    if not os.path.exists(count_file):
        with open(txt_path, "r") as file:
            words = file.read().upper().split()
            word_count = {}
            for word in words:
                word = remove_non_alphabetic(word, strict=False)
                word = word.split()
                for w in word:
                    if w not in word_count:
                        word_count[w] = 1
                    else:
                        word_count[w] += 1

            word_count = list(word_count.items())  # convert to a list of tuples
            word_count = sorted(word_count, key=lambda w: int(w[1]), reverse=True)
            with open(count_file, "w") as fout:
                for w, count in word_count:
                    fout.write(f"{w}\t{count}\n")
    else:
        with open(count_file, "r") as fin:
            word_count = fin.read().strip().split("\n")
            word_count = [pair.split("\t") for pair in word_count]
            word_count = sorted(word_count, key=lambda w: int(w[1]), reverse=True)

    print(f"A total of {len(word_count)} words appeared!")
    rare_words = []
    for word, _ in word_count[top_k:]:
        rare_words.append(word + "\n")
    print(f"A total of {len(rare_words)} words are identified as rare words.")

    with open(rare_word_file, "w") as f:
        f.writelines(rare_words)
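

# A sketch of the intermediate files produced above (layout inferred from the
# code; WORD1/WORD2 are placeholders):
#
#     data/tmp/transcript_words_medium_count.txt
#         WORD1<TAB>1234
#         WORD2<TAB>987
#     data/context_biasing/medium_rare_words_topk_10000.txt
#         WORD2
#         ...    (one rare word per line)

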
def add_context_list_to_manifest(
    manifest_dir: str,
    subset: str = "medium",
    top_k: int = 10000,
):
    """Generate a context list of rare words for each utterance in the manifest.

    The augmented manifest is written to the same directory, with
    `_with_context_list_topk_{top_k}` inserted before the `.jsonl.gz` suffix.

    Args:
        manifest_dir: The directory containing the original manifest
        subset: Which subset to use
        top_k: The number of most frequent words treated as common; the
            remaining words form the rare-word list
    """
    orig_manifest_path = f"{manifest_dir}/libriheavy_cuts_{subset}.jsonl.gz"
    target_manifest_path = orig_manifest_path.replace(
        ".jsonl.gz", f"_with_context_list_topk_{top_k}.jsonl.gz"
    )
    if os.path.exists(target_manifest_path):
        print(f"Target file exists at {target_manifest_path}!")
        return

    rare_words_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt"
    print(f"---Reading rare words from {rare_words_file}---")
    with open(rare_words_file, "r") as f:
        rare_words = f.read()
    rare_words = rare_words.split("\n")
    rare_words = set(rare_words)
    print(f"A total of {len(rare_words)} rare words!")

    cuts = load_manifest_lazy(orig_manifest_path)
    print(f"Loaded manifest from {orig_manifest_path}")

    def _add_context(c: Cut):
        splits = (
            remove_non_alphabetic(c.supervisions[0].texts[0], strict=False)
            .upper()
            .split()
        )
        found = []
        for w in splits:
            if w in rare_words:
                found.append(w)
        # Store the rare words of this utterance as a space-separated string
        c.supervisions[0].context_list = " ".join(found)
        return c

    cuts = cuts.map(_add_context)
    print(f"---Saving manifest with context list to {target_manifest_path}---")
    cuts.to_file(target_manifest_path)
    print("Finished")
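

# After add_context_list_to_manifest, each cut's first supervision carries a
# `context_list` attribute, e.g. (WORD1/WORD2 are placeholders):
#
#     cut.supervisions[0].context_list  # -> "WORD1 WORD2", or "" if the
#                                       #    utterance contains no rare words

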
def check(
    manifest_dir: str,
    subset: str = "medium",
    top_k: int = 10000,
):
    # Show how many samples in the training set have a non-empty context list
    # and the average length of the non-empty context lists
    print("--- Calculating the stats over the manifest ---")

    manifest_path = f"{manifest_dir}/libriheavy_cuts_{subset}_with_context_list_topk_{top_k}.jsonl.gz"
    cuts = load_manifest_lazy(manifest_path)
    total_cuts = len(cuts)
    has_context_list = [c.supervisions[0].context_list != "" for c in cuts]
    context_list_len = [len(c.supervisions[0].context_list.split()) for c in cuts]
    print(f"{sum(has_context_list)}/{total_cuts} cuts have a context list!")
    print(
        f"Average length of non-empty context list is {sum(context_list_len)/sum(has_context_list)}"
    )


def write_error_stats(
    f: TextIO,
    test_set_name: str,
    results: List[Tuple[str, List[str], List[str]]],
    enable_log: bool = True,
    compute_CER: bool = False,
    biasing_words: Optional[List[str]] = None,
) -> float:
    """Write statistics based on predicted results and reference transcripts.
    It also calculates the biased word error rate as described in
    https://arxiv.org/pdf/2104.02194.pdf

    It will write the following to the given file:

        - WER
        - number of insertions, deletions, substitutions, corrects and total
          reference words. For example::

              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
              reference words (2337 correct)

        - The difference between the reference transcript and predicted result.
          An instance is given below::

              THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

          The above example shows that the reference word is `EDISON`,
          but it is predicted to `ADDISON` (a substitution error).

          Another example is::

              FOR THE FIRST DAY (SIR->*) I THINK

          The reference word `SIR` is missing in the predicted
          results (a deletion error).

    Args:
        f:
            The file to which the statistics are written.
        test_set_name:
            The name of the test set, used in the log message.
        results:
            An iterable of tuples. The first element is the cut_id, the second is
            the reference transcript and the third element is the predicted result.
        enable_log:
            If True, also print detailed WER to the console.
            Otherwise, it is written only to the given file.
        compute_CER:
            If True, compute the character error rate instead of the word error
            rate by splitting the reference and hypothesis into characters.
        biasing_words:
            All the words in the biasing list. If given, also report the biased
            WER (B-WER, over words in the list) and the un-biased WER (U-WER,
            over the remaining words).

    Returns:
        The overall error rate (in percent) as a float.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"

    if compute_CER:
        for i, res in enumerate(results):
            cut_id, ref, hyp = res
            ref = list("".join(ref))
            hyp = list("".join(hyp))
            results[i] = (cut_id, ref, hyp)

    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:  # INSERTION
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:  # DELETION
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:  # SUBSTITUTION
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
    ref_len = sum([len(r) for _, r, _ in results])
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

    if enable_log:
        logging.info(
            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
            f"{del_errs} del, {sub_errs} sub ]"
        )

    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )

    print("", file=f)
    print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        combine_successive_errors = True
        if combine_successive_errors:
            ali = [[[x], [y]] for x, y in ali]
            for i in range(len(ali) - 1):
                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
                    ali[i] = [[], []]
            ali = [
                [
                    list(filter(lambda a: a != ERR, x)),
                    list(filter(lambda a: a != ERR, y)),
                ]
                for x, y in ali
            ]
            ali = list(filter(lambda x: x != [[], []], ali))
            ali = [
                [
                    ERR if x == [] else " ".join(x),
                    ERR if y == [] else " ".join(y),
                ]
                for x, y in ali
            ]

        print(
            f"{cut_id}:\t"
            + " ".join(
                (
                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                    for ref_word, hyp_word in ali
                )
            ),
            file=f,
        )

    print("", file=f)
    print("SUBSTITUTIONS: count ref -> hyp", file=f)

    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
        print(f"{count} {ref} -> {hyp}", file=f)

    print("", file=f)
    print("DELETIONS: count ref", file=f)
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count} {ref}", file=f)

    print("", file=f)
    print("INSERTIONS: count hyp", file=f)
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count} {hyp}", file=f)

    unbiased_word_counts = 0
    unbiased_word_errs = 0
    biased_word_counts = 0
    biased_word_errs = 0

    print("", file=f)
    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)

    for _, word, counts in sorted(
        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
    ):
        (corr, ref_sub, hyp_sub, ins, dels) = counts
        tot_errs = ref_sub + hyp_sub + ins + dels
        # number of appearances of "word" in reference text
        ref_count = (
            corr + ref_sub + dels
        )  # correct + in ref but got substituted + deleted
        # number of appearances of "word" in hyp text
        hyp_count = corr + hyp_sub + ins

        if biasing_words is not None:
            if word in biasing_words:
                biased_word_counts += ref_count
                biased_word_errs += ins + dels + ref_sub
            else:
                unbiased_word_counts += ref_count
                unbiased_word_errs += ins + dels + hyp_sub

        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)

    if biasing_words is not None:
        B_WER = "%.2f" % (100 * biased_word_errs / biased_word_counts)
        U_WER = "%.2f" % (100 * unbiased_word_errs / unbiased_word_counts)
        logging.info(f"Biased WER: {B_WER} [{biased_word_errs}/{biased_word_counts}]")
        logging.info(
            f"Un-biased WER: {U_WER} [{unbiased_word_errs}/{unbiased_word_counts}]"
        )

    return float(tot_err_rate)
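

# A minimal usage sketch for write_error_stats; the cut id, words and file name
# are hypothetical, and `biasing_words` could e.g. be derived from
# get_facebook_biasing_list:
#
#     results = [
#         ("cut-1", ["THE", "EDISON", "COMPANY"], ["THE", "ADDISON", "COMPANY"])
#     ]
#     with open("errs-test-clean.txt", "w") as err_f:
#         wer = write_error_stats(
#             err_f, "test-clean", results, biasing_words=["EDISON"]
#         )

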
if __name__ == "__main__":
|
|
parser = get_parser()
|
|
args = parser.parse_args()
|
|
manifest_dir = args.manifest_dir
|
|
subset = args.subset
|
|
top_k = args.top_k
|
|
get_rare_words(subset=subset, top_k=top_k)
|
|
add_context_list_to_manifest(
|
|
manifest_dir=manifest_dir,
|
|
subset=subset,
|
|
top_k=top_k,
|
|
)
|
|
check(
|
|
manifest_dir=manifest_dir,
|
|
subset=subset,
|
|
top_k=top_k,
|
|
)
|