marcoyang1998 2023-09-15 16:08:27 +08:00
parent a0fe6bcd0d
commit d411ffb4b6
2 changed files with 76 additions and 2173 deletions

File diff suppressed because it is too large


@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py


@@ -1,14 +1,16 @@
from typing import Dict, List, Tuple, TextIO, Union, Iterable
import ast
import glob
import logging
import os
from collections import defaultdict
from typing import Dict, Iterable, List, TextIO, Tuple, Union
import kaldialign
from lhotse import load_manifest, load_manifest_lazy
from lhotse.cut import Cut, CutSet
from text_normalization import remove_non_alphabetic
from tqdm import tqdm
import os
import kaldialign
import logging
def get_facebook_biasing_list(
test_set: str,
@@ -16,67 +18,68 @@ def get_facebook_biasing_list(
num_distractors: int = 100,
) -> Dict:
# Get the biasing list from the meta paper: https://arxiv.org/pdf/2104.02194.pdf
assert num_distractors in (0, 100,500,1000,2000), num_distractors
assert num_distractors in (0, 100, 500, 1000, 2000), num_distractors
if num_distractors == 0:
if test_set == "test-clean":
biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_100.tsv"
biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_100.tsv"
elif test_set == "test-other":
biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_100.tsv"
biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_100.tsv"
else:
raise ValueError(f"Unseen test set {test_set}")
else:
if test_set == "test-clean":
biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv"
biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv"
elif test_set == "test-other":
biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv"
biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv"
else:
raise ValueError(f"Unseen test set {test_set}")
f = open(biasing_file, 'r')
f = open(biasing_file, "r")
data = f.readlines()
f.close()
output = dict()
for line in data:
id, _, l1, l2 = line.split('\t')
id, _, l1, l2 = line.split("\t")
if use_distractors:
biasing_list = ast.literal_eval(l2)
else:
biasing_list = ast.literal_eval(l1)
biasing_list = [w.strip().upper() for w in biasing_list]
biasing_list = [w.strip().upper() for w in biasing_list]
output[id] = " ".join(biasing_list)
return output
def brian_biasing_list(level: str):
# The biasing list from Brian's paper: https://arxiv.org/pdf/2109.00627.pdf
import glob
root_dir = f"data/context_biasing/LibriSpeechBiasingLists/{level}Level"
all_files = glob.glob(root_dir + "/*")
biasing_dict = {}
for f in all_files:
k = f.split('/')[-1]
fin = open(f, 'r')
k = f.split("/")[-1]
fin = open(f, "r")
data = fin.read().strip().split()
biasing_dict[k] = " ".join(data)
fin.close()
return biasing_dict
def get_rare_words(subset: str, min_count: int):
"""Get a list of rare words appearing less than `min_count` times
Args:
subset:
subset: The dataset
min_count (int): Count of appearance
"""
txt_path = f"data/tmp/transcript_words_{subset}.txt"
rare_word_file = f"data/context_biasing/{subset}_rare_words_{min_count}.txt"
if os.path.exists(rare_word_file):
print("File exists, do not proceed!")
return
print(f"Finding rare words in the manifest.")
print("Finding rare words in the manifest.")
count_file = f"data/tmp/transcript_words_{subset}_count.txt"
if not os.path.exists(count_file):
with open(txt_path, "r") as file:
@@ -90,27 +93,28 @@ def get_rare_words(subset: str, min_count: int):
word_count[w] = 1
else:
word_count[w] += 1
with open(count_file, 'w') as fout:
with open(count_file, "w") as fout:
for w in word_count:
fout.write(f"{w}\t{word_count[w]}\n")
else:
word_count = {}
with open(count_file, 'r') as fin:
word_count = fin.read().strip().split('\n')
word_count = [pair.split('\t') for pair in word_count]
with open(count_file, "r") as fin:
word_count = fin.read().strip().split("\n")
word_count = [pair.split("\t") for pair in word_count]
word_count = sorted(word_count, key=lambda w: int(w[1]), reverse=True)
print(f"A total of {len(word_count)} words appeared!")
rare_words = []
for k in word_count:
if int(word_count[k]) <= min_count:
rare_words.append(k+"\n")
rare_words.append(k + "\n")
print(f"A total of {len(rare_words)} appeared <= {min_count} times")
with open(rare_word_file, 'w') as f:
with open(rare_word_file, "w") as f:
f.writelines(rare_words)
def add_context_list_to_manifest(subset: str, min_count: int):
"""Generate a context list of rare words for each utterance in the manifest
@@ -121,24 +125,30 @@ def add_context_list_to_manifest(subset: str, min_count: int):
"""
rare_words_file = f"data/context_biasing/{subset}_rare_words_{min_count}.txt"
manifest_dir = f"data/fbank/libriheavy_cuts_{subset}.jsonl.gz"
target_manifest_dir = manifest_dir.replace(".jsonl.gz", f"_with_context_list_min_count_{min_count}.jsonl.gz")
target_manifest_dir = manifest_dir.replace(
".jsonl.gz", f"_with_context_list_min_count_{min_count}.jsonl.gz"
)
if os.path.exists(target_manifest_dir):
print(f"Target file exits at {target_manifest_dir}!")
return
print(f"Reading rare words from {rare_words_file}")
with open(rare_words_file, "r") as f:
rare_words = f.read()
rare_words = rare_words.split("\n")
rare_words = set(rare_words)
print(f"A total of {len(rare_words)} rare words!")
cuts = load_manifest_lazy(manifest_dir)
print(f"Loaded manifest from {manifest_dir}")
def _add_context(c: Cut):
splits = remove_non_alphabetic(c.supervisions[0].texts[0], strict=False).upper().split()
splits = (
remove_non_alphabetic(c.supervisions[0].texts[0], strict=False)
.upper()
.split()
)
found = []
for w in splits:
if w in rare_words:
@@ -147,7 +157,7 @@ def add_context_list_to_manifest(subset: str, min_count: int):
return c
cuts = cuts.map(_add_context)
cuts.to_file(target_manifest_dir)
print(f"Saved manifest with context list to {target_manifest_dir}")
@@ -161,7 +171,9 @@ def check(subset: str, min_count: int):
has_context_list = [c.supervisions[0].context_list != "" for c in cuts]
context_list_len = [len(c.supervisions[0].context_list.split()) for c in cuts]
print(f"{sum(has_context_list)}/{total_cuts} cuts have context list! ")
print(f"Average length of non-empty context list is {sum(context_list_len)/sum(has_context_list)}")
print(
f"Average length of non-empty context list is {sum(context_list_len)/sum(has_context_list)}"
)
def write_error_stats(
@@ -218,24 +230,24 @@ def write_error_stats(
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
num_corr = 0
ERR = "*"
if compute_CER:
for i, res in enumerate(results):
cut_id, ref, hyp = res
ref = list("".join(ref))
hyp = list("".join(hyp))
results[i] = (cut_id, ref, hyp)
for cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR)
for ref_word, hyp_word in ali:
if ref_word == ERR: # INSERTION
if ref_word == ERR: # INSERTION
ins[hyp_word] += 1
words[hyp_word][3] += 1
elif hyp_word == ERR: # DELETION
elif hyp_word == ERR: # DELETION
dels[ref_word] += 1
words[ref_word][4] += 1
elif hyp_word != ref_word: # SUBSTITUTION
elif hyp_word != ref_word: # SUBSTITUTION
subs[(ref_word, hyp_word)] += 1
words[ref_word][1] += 1
words[hyp_word][2] += 1
@@ -301,9 +313,7 @@ def write_error_stats(
f"{cut_id}:\t"
+ " ".join(
(
ref_word
if ref_word == hyp_word
else f"({ref_word}->{hyp_word})"
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
for ref_word, hyp_word in ali
)
),
@@ -313,9 +323,7 @@ def write_error_stats(
print("", file=f)
print("SUBSTITUTIONS: count ref -> hyp", file=f)
for count, (ref, hyp) in sorted(
[(v, k) for k, v in subs.items()], reverse=True
):
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
print(f"{count} {ref} -> {hyp}", file=f)
print("", file=f)
@@ -332,11 +340,9 @@ def write_error_stats(
unbiased_word_errs = 0
biased_word_counts = 0
biased_word_errs = 0
print("", file=f)
print(
"PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
)
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
for _, word, counts in sorted(
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
@@ -344,37 +350,36 @@ def write_error_stats(
(corr, ref_sub, hyp_sub, ins, dels) = counts
tot_errs = ref_sub + hyp_sub + ins + dels
# number of appearances of "word" in reference text
ref_count = corr + ref_sub + dels # correct + in ref but got substituted + deleted
ref_count = (
corr + ref_sub + dels
) # correct + in ref but got substituted + deleted
# number of appearances of "word" in hyp text
hyp_count = corr + hyp_sub + ins
if biasing_words is not None:
if word in biasing_words:
biased_word_counts += ref_count
biased_word_errs += (ins + dels + ref_sub)
biased_word_errs += ins + dels + ref_sub
else:
unbiased_word_counts += ref_count
unbiased_word_errs += (ins + dels + hyp_sub)
unbiased_word_errs += ins + dels + hyp_sub
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
if biasing_words is not None:
B_WER = "%.2f" % (100 *biased_word_errs/biased_word_counts)
U_WER = "%.2f" % (100 *unbiased_word_errs/unbiased_word_counts)
B_WER = "%.2f" % (100 * biased_word_errs / biased_word_counts)
U_WER = "%.2f" % (100 * unbiased_word_errs / unbiased_word_counts)
logging.info(f"Biased WER: {B_WER} [{biased_word_errs}/{biased_word_counts}] ")
logging.info(f"Un-biased WER: {U_WER} [{unbiased_word_errs}/{unbiased_word_counts}]")
logging.info(
f"Un-biased WER: {U_WER} [{unbiased_word_errs}/{unbiased_word_counts}]"
)
return float(tot_err_rate)
if __name__=="__main__":
#test_set = "test-clean"
#get_facebook_biasing_list(test_set)
if __name__ == "__main__":
subset = "medium"
min_count = 460
#get_rare_words(subset, min_count)
min_count = 10
get_rare_words(subset, min_count)
add_context_list_to_manifest(subset=subset, min_count=min_count)
check(subset=subset, min_count=min_count)
check(subset=subset, min_count=min_count)
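
For orientation, a minimal usage sketch of the helpers touched by this diff, not part of the commit itself. It assumes the edited file is importable as `utils`, that the `data/` layout referenced above (Libriheavy transcripts, fbank manifests, and the fbai-speech biasing TSVs) is already prepared, and that `use_distractors` and `biasing_words` are keyword parameters, as the function bodies suggest.

# Hedged sketch, not part of the commit: module name `utils`, data layout, and the
# write_error_stats call signature are assumptions based on how the diff uses them.
from utils import (
    add_context_list_to_manifest,
    check,
    get_facebook_biasing_list,
    get_rare_words,
    write_error_stats,
)

# 1. Build per-utterance context lists of rare words for a Libriheavy subset,
#    mirroring the updated __main__ block (subset="medium", min_count=10).
subset = "medium"
min_count = 10
get_rare_words(subset, min_count)  # writes data/context_biasing/medium_rare_words_10.txt
add_context_list_to_manifest(subset=subset, min_count=min_count)
check(subset=subset, min_count=min_count)  # prints how many cuts got a context list

# 2. Fetch the per-cut biasing lists for a LibriSpeech test set, keyed by cut id,
#    optionally padded with distractor words from the TSV files.
biasing = get_facebook_biasing_list(
    test_set="test-clean",
    use_distractors=True,  # assumed keyword; the body branches on `use_distractors`
    num_distractors=100,
)
cut_id = next(iter(biasing))
print(cut_id, biasing[cut_id][:80])  # values are space-joined upper-cased words

# 3. Scoring: passing `biasing_words` to write_error_stats additionally reports
#    biased (B_WER) vs. un-biased (U_WER) word error rates. `results` would come
#    from a decoding script as (cut_id, ref_words, hyp_words) tuples; the call
#    below assumes the usual icefall-style positional arguments.
# biasing_words = set(" ".join(biasing.values()).split())
# with open("errs-test-clean.txt", "w") as f:
#     write_error_stats(f, "test-clean", results, biasing_words=biasing_words)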