This commit is contained in:
marcoyang1998 2023-09-15 16:08:27 +08:00
parent a0fe6bcd0d
commit d411ffb4b6
2 changed files with 76 additions and 2173 deletions

File diff suppressed because it is too large


@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
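The single added line above is the file's new content; in icefall recipes a one-line relative path like this is how a git symlink to the shared LibriSpeech implementation appears. A minimal, hypothetical sketch of recreating the link locally (the link name beam_search.py and the working directory are assumptions, not shown in this diff):

import os

# Assumed: run from the recipe directory that should contain beam_search.py.
target = "../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py"
link_name = "beam_search.py"  # hypothetical name; the diff does not show it

if os.path.lexists(link_name):
    os.remove(link_name)  # replace the old full copy with the link
os.symlink(target, link_name)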


@@ -1,14 +1,16 @@
-from typing import Dict, List, Tuple, TextIO, Union, Iterable
 import ast
+import glob
+import logging
+import os
 from collections import defaultdict
+from typing import Dict, Iterable, List, TextIO, Tuple, Union
+import kaldialign
 from lhotse import load_manifest, load_manifest_lazy
 from lhotse.cut import Cut, CutSet
 from text_normalization import remove_non_alphabetic
 from tqdm import tqdm
-import os
-import kaldialign
-import logging
@@ -16,67 +18,68 @@ def get_facebook_biasing_list(
 def get_facebook_biasing_list(
     test_set: str,
     num_distractors: int = 100,
 ) -> Dict:
     # Get the biasing list from the meta paper: https://arxiv.org/pdf/2104.02194.pdf
-    assert num_distractors in (0, 100,500,1000,2000), num_distractors
+    assert num_distractors in (0, 100, 500, 1000, 2000), num_distractors
     if num_distractors == 0:
         if test_set == "test-clean":
-            biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_100.tsv"
+            biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_100.tsv"
         elif test_set == "test-other":
-            biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_100.tsv"
+            biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_100.tsv"
         else:
             raise ValueError(f"Unseen test set {test_set}")
     else:
         if test_set == "test-clean":
             biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv"
         elif test_set == "test-other":
             biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv"
         else:
             raise ValueError(f"Unseen test set {test_set}")
-    f = open(biasing_file, 'r')
+    f = open(biasing_file, "r")
     data = f.readlines()
     f.close()
     output = dict()
     for line in data:
-        id, _, l1, l2 = line.split('\t')
+        id, _, l1, l2 = line.split("\t")
         if use_distractors:
             biasing_list = ast.literal_eval(l2)
         else:
             biasing_list = ast.literal_eval(l1)
         biasing_list = [w.strip().upper() for w in biasing_list]
         output[id] = " ".join(biasing_list)
     return output
 def brian_biasing_list(level: str):
     # The biasing list from Brian's paper: https://arxiv.org/pdf/2109.00627.pdf
-    import glob
     root_dir = f"data/context_biasing/LibriSpeechBiasingLists/{level}Level"
     all_files = glob.glob(root_dir + "/*")
     biasing_dict = {}
     for f in all_files:
-        k = f.split('/')[-1]
-        fin = open(f, 'r')
+        k = f.split("/")[-1]
+        fin = open(f, "r")
         data = fin.read().strip().split()
         biasing_dict[k] = " ".join(data)
         fin.close()
     return biasing_dict
 def get_rare_words(subset: str, min_count: int):
     """Get a list of rare words appearing less than `min_count` times
     Args:
-        subset:
+        subset: The dataset
         min_count (int): Count of appearance
     """
     txt_path = f"data/tmp/transcript_words_{subset}.txt"
     rare_word_file = f"data/context_biasing/{subset}_rare_words_{min_count}.txt"
     if os.path.exists(rare_word_file):
         print("File exists, do not proceed!")
         return
-    print(f"Finding rare words in the manifest.")
+    print("Finding rare words in the manifest.")
     count_file = f"data/tmp/transcript_words_{subset}_count.txt"
     if not os.path.exists(count_file):
         with open(txt_path, "r") as file:
@@ -90,27 +93,28 @@ def get_rare_words(subset: str, min_count: int):
                     word_count[w] = 1
                 else:
                     word_count[w] += 1
-        with open(count_file, 'w') as fout:
+        with open(count_file, "w") as fout:
             for w in word_count:
                 fout.write(f"{w}\t{word_count[w]}\n")
     else:
         word_count = {}
-        with open(count_file, 'r') as fin:
-            word_count = fin.read().strip().split('\n')
-            word_count = [pair.split('\t') for pair in word_count]
+        with open(count_file, "r") as fin:
+            word_count = fin.read().strip().split("\n")
+            word_count = [pair.split("\t") for pair in word_count]
             word_count = sorted(word_count, key=lambda w: int(w[1]), reverse=True)
     print(f"A total of {len(word_count)} words appeared!")
     rare_words = []
     for k in word_count:
         if int(word_count[k]) <= min_count:
-            rare_words.append(k+"\n")
+            rare_words.append(k + "\n")
     print(f"A total of {len(rare_words)} appeared <= {min_count} times")
-    with open(rare_word_file, 'w') as f:
+    with open(rare_word_file, "w") as f:
         f.writelines(rare_words)
@@ -121,24 +125,30 @@ def add_context_list_to_manifest(subset: str, min_count: int):
 def add_context_list_to_manifest(subset: str, min_count: int):
     """Generate a context list of rare words for each utterance in the manifest
     """
     rare_words_file = f"data/context_biasing/{subset}_rare_words_{min_count}.txt"
     manifest_dir = f"data/fbank/libriheavy_cuts_{subset}.jsonl.gz"
-    target_manifest_dir = manifest_dir.replace(".jsonl.gz", f"_with_context_list_min_count_{min_count}.jsonl.gz")
+    target_manifest_dir = manifest_dir.replace(
+        ".jsonl.gz", f"_with_context_list_min_count_{min_count}.jsonl.gz"
+    )
     if os.path.exists(target_manifest_dir):
         print(f"Target file exits at {target_manifest_dir}!")
         return
     print(f"Reading rare words from {rare_words_file}")
     with open(rare_words_file, "r") as f:
         rare_words = f.read()
     rare_words = rare_words.split("\n")
     rare_words = set(rare_words)
     print(f"A total of {len(rare_words)} rare words!")
     cuts = load_manifest_lazy(manifest_dir)
     print(f"Loaded manifest from {manifest_dir}")
     def _add_context(c: Cut):
-        splits = remove_non_alphabetic(c.supervisions[0].texts[0], strict=False).upper().split()
+        splits = (
+            remove_non_alphabetic(c.supervisions[0].texts[0], strict=False)
+            .upper()
+            .split()
+        )
         found = []
         for w in splits:
             if w in rare_words:
@@ -147,7 +157,7 @@ def add_context_list_to_manifest(subset: str, min_count: int):
         return c
     cuts = cuts.map(_add_context)
     cuts.to_file(target_manifest_dir)
     print(f"Saved manifest with context list to {target_manifest_dir}")
@@ -161,7 +171,9 @@ def check(subset: str, min_count: int):
     has_context_list = [c.supervisions[0].context_list != "" for c in cuts]
     context_list_len = [len(c.supervisions[0].context_list.split()) for c in cuts]
    print(f"{sum(has_context_list)}/{total_cuts} cuts have context list! ")
-    print(f"Average length of non-empty context list is {sum(context_list_len)/sum(has_context_list)}")
+    print(
+        f"Average length of non-empty context list is {sum(context_list_len)/sum(has_context_list)}"
+    )
@@ -218,24 +230,24 @@ def write_error_stats(
     words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
     num_corr = 0
     ERR = "*"
     if compute_CER:
         for i, res in enumerate(results):
             cut_id, ref, hyp = res
             ref = list("".join(ref))
             hyp = list("".join(hyp))
             results[i] = (cut_id, ref, hyp)
     for cut_id, ref, hyp in results:
         ali = kaldialign.align(ref, hyp, ERR)
         for ref_word, hyp_word in ali:
             if ref_word == ERR:  # INSERTION
                 ins[hyp_word] += 1
                 words[hyp_word][3] += 1
             elif hyp_word == ERR:  # DELETION
                 dels[ref_word] += 1
                 words[ref_word][4] += 1
             elif hyp_word != ref_word:  # SUBSTITUTION
                 subs[(ref_word, hyp_word)] += 1
                 words[ref_word][1] += 1
                 words[hyp_word][2] += 1
@@ -301,9 +313,7 @@ def write_error_stats(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word
-                    if ref_word == hyp_word
-                    else f"({ref_word}->{hyp_word})"
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -313,9 +323,7 @@ def write_error_stats(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
-    for count, (ref, hyp) in sorted(
-        [(v, k) for k, v in subs.items()], reverse=True
-    ):
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
         print(f"{count} {ref} -> {hyp}", file=f)
     print("", file=f)
@@ -332,11 +340,9 @@ def write_error_stats(
     unbiased_word_errs = 0
     biased_word_counts = 0
     biased_word_errs = 0
     print("", file=f)
-    print(
-        "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
-    )
+    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
@@ -344,37 +350,36 @@ def write_error_stats(
         (corr, ref_sub, hyp_sub, ins, dels) = counts
         tot_errs = ref_sub + hyp_sub + ins + dels
         # number of appearances of "word" in reference text
-        ref_count = corr + ref_sub + dels  # correct + in ref but got substituted + deleted
+        ref_count = (
+            corr + ref_sub + dels
+        )  # correct + in ref but got substituted + deleted
         # number of appearances of "word" in hyp text
         hyp_count = corr + hyp_sub + ins
         if biasing_words is not None:
             if word in biasing_words:
                 biased_word_counts += ref_count
-                biased_word_errs += (ins + dels + ref_sub)
+                biased_word_errs += ins + dels + ref_sub
             else:
                 unbiased_word_counts += ref_count
-                unbiased_word_errs += (ins + dels + hyp_sub)
+                unbiased_word_errs += ins + dels + hyp_sub
         print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
     if biasing_words is not None:
-        B_WER = "%.2f" % (100 *biased_word_errs/biased_word_counts)
-        U_WER = "%.2f" % (100 *unbiased_word_errs/unbiased_word_counts)
+        B_WER = "%.2f" % (100 * biased_word_errs / biased_word_counts)
+        U_WER = "%.2f" % (100 * unbiased_word_errs / unbiased_word_counts)
         logging.info(f"Biased WER: {B_WER} [{biased_word_errs}/{biased_word_counts}] ")
-        logging.info(f"Un-biased WER: {U_WER} [{unbiased_word_errs}/{unbiased_word_counts}]")
+        logging.info(
+            f"Un-biased WER: {U_WER} [{unbiased_word_errs}/{unbiased_word_counts}]"
+        )
     return float(tot_err_rate)
-if __name__=="__main__":
-    #test_set = "test-clean"
-    #get_facebook_biasing_list(test_set)
+if __name__ == "__main__":
     subset = "medium"
-    min_count = 460
-    #get_rare_words(subset, min_count)
+    min_count = 10
+    get_rare_words(subset, min_count)
     add_context_list_to_manifest(subset=subset, min_count=min_count)
     check(subset=subset, min_count=min_count)
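For orientation, here is a minimal sketch of how the helpers touched by this commit fit together, mirroring the rewritten __main__ block. The module name utils, the default values of the remaining parameters, and the presence of the data/tmp transcripts and Libriheavy fbank manifests are assumptions, not guaranteed by the diff:

# Hypothetical driver; function names come from the diff, the import path is assumed.
from utils import (
    add_context_list_to_manifest,
    check,
    get_facebook_biasing_list,
    get_rare_words,
)

subset = "medium"
min_count = 10

# 1) Collect words that appear <= min_count times in the transcripts.
get_rare_words(subset, min_count)

# 2) Attach a per-utterance context list of those rare words to the Libriheavy manifest.
add_context_list_to_manifest(subset=subset, min_count=min_count)

# 3) Report how many cuts ended up with a non-empty context list.
check(subset=subset, min_count=min_count)

# Separately, build a cut-id -> space-joined biasing-word string for a LibriSpeech
# test set (the other keyword arguments are assumed to have defaults).
biasing = get_facebook_biasing_list(test_set="test-clean", num_distractors=100)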