From 81af525de47976b485941bea6414082c21e49b86 Mon Sep 17 00:00:00 2001 From: marcoyang1998 Date: Fri, 8 Sep 2023 10:15:21 +0800 Subject: [PATCH] update the biasing lists --- .../ASR/zipformer_prompt_asr/utils.py | 275 +++++++++++++++++- 1 file changed, 262 insertions(+), 13 deletions(-) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/utils.py b/egs/libriheavy/ASR/zipformer_prompt_asr/utils.py index e6bbcc8ee..765c4c653 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/utils.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/utils.py @@ -1,23 +1,36 @@ -from typing import Dict +from typing import Dict, List, Tuple, TextIO, Union, Iterable import ast +from collections import defaultdict from lhotse import load_manifest, load_manifest_lazy from lhotse.cut import Cut, CutSet from text_normalization import remove_non_alphabetic from tqdm import tqdm import os +import kaldialign +import logging + def get_facebook_biasing_list( test_set: str, use_distractors: bool = False, num_distractors: int = 100, ) -> Dict: - assert num_distractors in (100,500,1000,2000), num_distractors - if test_set == "test-clean": - biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv" - elif test_set == "test-other": - biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv" + # Get the biasing list from the meta paper: https://arxiv.org/pdf/2104.02194.pdf + assert num_distractors in (0, 100,500,1000,2000), num_distractors + if num_distractors == 0: + if test_set == "test-clean": + biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_100.tsv" + elif test_set == "test-other": + biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_100.tsv" + else: + raise ValueError(f"Unseen test set {test_set}") else: - raise ValueError(f"Unseen test set {test_set}") + if test_set == "test-clean": + biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv" + elif test_set == "test-other": + biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv" + else: + raise ValueError(f"Unseen test set {test_set}") f = open(biasing_file, 'r') data = f.readlines() @@ -35,7 +48,28 @@ def get_facebook_biasing_list( return output +def brian_biasing_list(level: str): + # The biasing list from Brian's paper: https://arxiv.org/pdf/2109.00627.pdf + import glob + root_dir = f"data/context_biasing/LibriSpeechBiasingLists/{level}Level" + all_files = glob.glob(root_dir + "/*") + biasing_dict = {} + for f in all_files: + k = f.split('/')[-1] + fin = open(f, 'r') + data = fin.read().strip().split() + biasing_dict[k] = " ".join(data) + fin.close() + + return biasing_dict + def get_rare_words(subset: str, min_count: int): + """Get a list of rare words appearing less than `min_count` times + + Args: + subset: + min_count (int): Count of appearance + """ txt_path = f"data/tmp/transcript_words_{subset}.txt" rare_word_file = f"data/context_biasing/{subset}_rare_words_{min_count}.txt" @@ -59,12 +93,13 @@ def get_rare_words(subset: str, min_count: int): with open(count_file, 'w') as fout: for w in word_count: - fout.write(f"{w}\t{word_count[w]}") + fout.write(f"{w}\t{word_count[w]}\n") else: word_count = {} with open(count_file, 'r') as fin: - word_count = fin.read().split('\n') - word_count = [pair.split() for pair in word_count] + word_count = fin.read().strip().split('\n') + word_count = [pair.split('\t') for pair in word_count] + word_count = sorted(word_count, key=lambda w: int(w[1]), reverse=True) print(f"A total of {len(word_count)} words appeared!") rare_words = [] @@ -77,6 +112,13 @@ def get_rare_words(subset: str, min_count: int): f.writelines(rare_words) def add_context_list_to_manifest(subset: str, min_count: int): + """Generate a context list of rare words for each utterance in the manifest + + Args: + subset (str): Subset + min_count (int): The min appearances + + """ rare_words_file = f"data/context_biasing/{subset}_rare_words_{min_count}.txt" manifest_dir = f"data/fbank/libriheavy_cuts_{subset}.jsonl.gz" @@ -111,7 +153,7 @@ def add_context_list_to_manifest(subset: str, min_count: int): def check(subset: str, min_count: int): - #manifest_dir = f"data/fbank/libriheavy_cuts_{subset}_with_context_list_min_count_{min_count}.jsonl.gz" + # Used to show how many samples in the training set have a context list print("Calculating the stats over the manifest") manifest_dir = f"data/fbank/libriheavy_cuts_{subset}_with_context_list_min_count_{min_count}.jsonl.gz" cuts = load_manifest_lazy(manifest_dir) @@ -120,12 +162,219 @@ def check(subset: str, min_count: int): context_list_len = [len(c.supervisions[0].context_list.split()) for c in cuts] print(f"{sum(has_context_list)}/{total_cuts} cuts have context list! ") print(f"Average length of non-empty context list is {sum(context_list_len)/sum(has_context_list)}") + + +def write_error_stats( + f: TextIO, + test_set_name: str, + results: List[Tuple[str, str]], + enable_log: bool = True, + compute_CER: bool = False, + biasing_words: List[str] = None, +) -> float: + """Write statistics based on predicted results and reference transcripts. It also calculates the + biasing word error rate as described in https://arxiv.org/pdf/2104.02194.pdf + + It will write the following to the given file: + + - WER + - number of insertions, deletions, substitutions, corrects and total + reference words. For example:: + + Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606 + reference words (2337 correct) + + - The difference between the reference transcript and predicted result. + An instance is given below:: + + THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES + + The above example shows that the reference word is `EDISON`, + but it is predicted to `ADDISON` (a substitution error). + + Another example is:: + + FOR THE FIRST DAY (SIR->*) I THINK + + The reference word `SIR` is missing in the predicted + results (a deletion error). + results: + An iterable of tuples. The first element is the cut_id, the second is + the reference transcript and the third element is the predicted result. + enable_log: + If True, also print detailed WER to the console. + Otherwise, it is written only to the given file. + biasing_words: + All the words in the biasing list + Returns: + Return None. + """ + subs: Dict[Tuple[str, str], int] = defaultdict(int) + ins: Dict[str, int] = defaultdict(int) + dels: Dict[str, int] = defaultdict(int) + + # `words` stores counts per word, as follows: + # corr, ref_sub, hyp_sub, ins, dels + words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) + num_corr = 0 + ERR = "*" + if compute_CER: + for i, res in enumerate(results): + cut_id, ref, hyp = res + ref = list("".join(ref)) + hyp = list("".join(hyp)) + results[i] = (cut_id, ref, hyp) + + for cut_id, ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + for ref_word, hyp_word in ali: + if ref_word == ERR: # INSERTION + ins[hyp_word] += 1 + words[hyp_word][3] += 1 + elif hyp_word == ERR: # DELETION + dels[ref_word] += 1 + words[ref_word][4] += 1 + elif hyp_word != ref_word: # SUBSTITUTION + subs[(ref_word, hyp_word)] += 1 + words[ref_word][1] += 1 + words[hyp_word][2] += 1 + else: + words[ref_word][0] += 1 + num_corr += 1 + ref_len = sum([len(r) for _, r, _ in results]) + sub_errs = sum(subs.values()) + ins_errs = sum(ins.values()) + del_errs = sum(dels.values()) + tot_errs = sub_errs + ins_errs + del_errs + tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) + + if enable_log: + logging.info( + f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " + f"[{tot_errs} / {ref_len}, {ins_errs} ins, " + f"{del_errs} del, {sub_errs} sub ]" + ) + + print(f"%WER = {tot_err_rate}", file=f) + print( + f"Errors: {ins_errs} insertions, {del_errs} deletions, " + f"{sub_errs} substitutions, over {ref_len} reference " + f"words ({num_corr} correct)", + file=f, + ) + print( + "Search below for sections starting with PER-UTT DETAILS:, " + "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:", + file=f, + ) + + print("", file=f) + print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) + for cut_id, ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + combine_successive_errors = True + if combine_successive_errors: + ali = [[[x], [y]] for x, y in ali] + for i in range(len(ali) - 1): + if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]: + ali[i + 1][0] = ali[i][0] + ali[i + 1][0] + ali[i + 1][1] = ali[i][1] + ali[i + 1][1] + ali[i] = [[], []] + ali = [ + [ + list(filter(lambda a: a != ERR, x)), + list(filter(lambda a: a != ERR, y)), + ] + for x, y in ali + ] + ali = list(filter(lambda x: x != [[], []], ali)) + ali = [ + [ + ERR if x == [] else " ".join(x), + ERR if y == [] else " ".join(y), + ] + for x, y in ali + ] + + print( + f"{cut_id}:\t" + + " ".join( + ( + ref_word + if ref_word == hyp_word + else f"({ref_word}->{hyp_word})" + for ref_word, hyp_word in ali + ) + ), + file=f, + ) + + print("", file=f) + print("SUBSTITUTIONS: count ref -> hyp", file=f) + + for count, (ref, hyp) in sorted( + [(v, k) for k, v in subs.items()], reverse=True + ): + print(f"{count} {ref} -> {hyp}", file=f) + + print("", file=f) + print("DELETIONS: count ref", file=f) + for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True): + print(f"{count} {ref}", file=f) + + print("", file=f) + print("INSERTIONS: count hyp", file=f) + for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True): + print(f"{count} {hyp}", file=f) + + unbiased_word_counts = 0 + unbiased_word_errs = 0 + biased_word_counts = 0 + biased_word_errs = 0 + + print("", file=f) + print( + "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f + ) + + for _, word, counts in sorted( + [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True + ): + (corr, ref_sub, hyp_sub, ins, dels) = counts + tot_errs = ref_sub + hyp_sub + ins + dels + # number of appearances of "word" in reference text + ref_count = corr + ref_sub + dels # correct + in ref but got substituted + deleted + # number of appearances of "word" in hyp text + hyp_count = corr + hyp_sub + ins + + + if biasing_words is not None: + if word in biasing_words: + biased_word_counts += ref_count + biased_word_errs += (ins + dels + ref_sub) + else: + unbiased_word_counts += ref_count + unbiased_word_errs += (ins + dels + hyp_sub) + + print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) + + if biasing_words is not None: + B_WER = "%.2f" % (100 *biased_word_errs/biased_word_counts) + U_WER = "%.2f" % (100 *unbiased_word_errs/unbiased_word_counts) + logging.info(f"Biased WER: {B_WER} [{biased_word_errs}/{biased_word_counts}] ") + logging.info(f"Un-biased WER: {U_WER} [{unbiased_word_errs}/{unbiased_word_counts}]") + + + return float(tot_err_rate) + + + if __name__=="__main__": #test_set = "test-clean" #get_facebook_biasing_list(test_set) subset = "medium" - min_count = 10 + min_count = 460 #get_rare_words(subset, min_count) - #add_context_list_to_manifest(subset=subset, min_count=min_count) + add_context_list_to_manifest(subset=subset, min_count=min_count) check(subset=subset, min_count=min_count) \ No newline at end of file