mirror of https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00

update

This commit is contained in:
parent a0fe6bcd0d
commit d411ffb4b6
File diff suppressed because it is too large.

egs/libriheavy/ASR/zipformer_prompt_asr/beam_search.py (symbolic link, 1 line)
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -1,14 +1,16 @@
-from typing import Dict, List, Tuple, TextIO, Union, Iterable
 import ast
+import glob
+import logging
+import os
 from collections import defaultdict
+from typing import Dict, Iterable, List, TextIO, Tuple, Union
+
+import kaldialign
 from lhotse import load_manifest, load_manifest_lazy
 from lhotse.cut import Cut, CutSet
 from text_normalization import remove_non_alphabetic
 from tqdm import tqdm
-import os
-
-import kaldialign
-import logging


 def get_facebook_biasing_list(
     test_set: str,
@@ -16,67 +18,68 @@ def get_facebook_biasing_list(
     num_distractors: int = 100,
 ) -> Dict:
     # Get the biasing list from the meta paper: https://arxiv.org/pdf/2104.02194.pdf
-    assert num_distractors in (0, 100,500,1000,2000), num_distractors
+    assert num_distractors in (0, 100, 500, 1000, 2000), num_distractors
     if num_distractors == 0:
         if test_set == "test-clean":
-            biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_100.tsv"
+            biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_100.tsv"
         elif test_set == "test-other":
-            biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_100.tsv"
+            biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_100.tsv"
         else:
             raise ValueError(f"Unseen test set {test_set}")
     else:
         if test_set == "test-clean":
-            biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv"
+            biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv"
         elif test_set == "test-other":
-            biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv"
+            biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv"
         else:
             raise ValueError(f"Unseen test set {test_set}")

-    f = open(biasing_file, 'r')
+    f = open(biasing_file, "r")
     data = f.readlines()
     f.close()

     output = dict()
     for line in data:
-        id, _, l1, l2 = line.split('\t')
+        id, _, l1, l2 = line.split("\t")
         if use_distractors:
             biasing_list = ast.literal_eval(l2)
         else:
             biasing_list = ast.literal_eval(l1)
         biasing_list = [w.strip().upper() for w in biasing_list]
         output[id] = " ".join(biasing_list)

     return output


 def brian_biasing_list(level: str):
     # The biasing list from Brian's paper: https://arxiv.org/pdf/2109.00627.pdf
-    import glob
     root_dir = f"data/context_biasing/LibriSpeechBiasingLists/{level}Level"
     all_files = glob.glob(root_dir + "/*")
     biasing_dict = {}
     for f in all_files:
-        k = f.split('/')[-1]
-        fin = open(f, 'r')
+        k = f.split("/")[-1]
+        fin = open(f, "r")
         data = fin.read().strip().split()
         biasing_dict[k] = " ".join(data)
         fin.close()

     return biasing_dict


 def get_rare_words(subset: str, min_count: int):
     """Get a list of rare words appearing less than `min_count` times

     Args:
-        subset:
+        subset: The dataset
         min_count (int): Count of appearance
     """
     txt_path = f"data/tmp/transcript_words_{subset}.txt"
     rare_word_file = f"data/context_biasing/{subset}_rare_words_{min_count}.txt"

     if os.path.exists(rare_word_file):
         print("File exists, do not proceed!")
         return
-    print(f"Finding rare words in the manifest.")
+    print("Finding rare words in the manifest.")
     count_file = f"data/tmp/transcript_words_{subset}_count.txt"
     if not os.path.exists(count_file):
         with open(txt_path, "r") as file:
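
For reference, a minimal sketch of the per-row parsing that get_facebook_biasing_list performs on the is21_deep_bias TSV files. The example row below is invented for illustration (the real files live under data/context_biasing/ and their exact column contents may differ):

    import ast

    # Hypothetical row: utterance id, a text column, then two Python-style lists
    # (biasing words from the reference, and the same words plus distractors).
    row = "1089-134686-0000\tthe farmer spoke of husbandry\t['HUSBANDRY', 'GRONINGEN']\t['HUSBANDRY', 'GRONINGEN', 'QUIMPER']\n"
    cut_id, _, l1, l2 = row.split("\t")
    biasing_list = ast.literal_eval(l1)  # use l2 instead when use_distractors=True
    biasing_list = [w.strip().upper() for w in biasing_list]
    print(cut_id, " ".join(biasing_list))  # -> 1089-134686-0000 HUSBANDRY GRONINGEN
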
@@ -90,27 +93,28 @@ def get_rare_words(subset: str, min_count: int):
                     word_count[w] = 1
                 else:
                     word_count[w] += 1

-        with open(count_file, 'w') as fout:
+        with open(count_file, "w") as fout:
             for w in word_count:
                 fout.write(f"{w}\t{word_count[w]}\n")
     else:
         word_count = {}
-        with open(count_file, 'r') as fin:
-            word_count = fin.read().strip().split('\n')
-            word_count = [pair.split('\t') for pair in word_count]
+        with open(count_file, "r") as fin:
+            word_count = fin.read().strip().split("\n")
+            word_count = [pair.split("\t") for pair in word_count]
             word_count = sorted(word_count, key=lambda w: int(w[1]), reverse=True)

     print(f"A total of {len(word_count)} words appeared!")
     rare_words = []
     for k in word_count:
         if int(word_count[k]) <= min_count:
-            rare_words.append(k+"\n")
+            rare_words.append(k + "\n")
     print(f"A total of {len(rare_words)} appeared <= {min_count} times")

-    with open(rare_word_file, 'w') as f:
+    with open(rare_word_file, "w") as f:
         f.writelines(rare_words)


 def add_context_list_to_manifest(subset: str, min_count: int):
     """Generate a context list of rare words for each utterance in the manifest

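
As a toy illustration of the rare-word selection done by get_rare_words (using an in-memory transcript instead of the recipe's data/tmp/transcript_words_*.txt files):

    from collections import defaultdict

    transcript = "THE SEA THE SHORE QUIMPER THE SEA"
    min_count = 1

    word_count = defaultdict(int)
    for w in transcript.split():
        word_count[w] += 1

    # keep words that appear at most min_count times
    rare_words = [w for w, c in word_count.items() if c <= min_count]
    print(rare_words)  # -> ['SHORE', 'QUIMPER']
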
@@ -121,24 +125,30 @@ def add_context_list_to_manifest(subset: str, min_count: int):
     """
     rare_words_file = f"data/context_biasing/{subset}_rare_words_{min_count}.txt"
     manifest_dir = f"data/fbank/libriheavy_cuts_{subset}.jsonl.gz"

-    target_manifest_dir = manifest_dir.replace(".jsonl.gz", f"_with_context_list_min_count_{min_count}.jsonl.gz")
+    target_manifest_dir = manifest_dir.replace(
+        ".jsonl.gz", f"_with_context_list_min_count_{min_count}.jsonl.gz"
+    )
     if os.path.exists(target_manifest_dir):
         print(f"Target file exits at {target_manifest_dir}!")
         return

     print(f"Reading rare words from {rare_words_file}")
     with open(rare_words_file, "r") as f:
         rare_words = f.read()
     rare_words = rare_words.split("\n")
     rare_words = set(rare_words)
     print(f"A total of {len(rare_words)} rare words!")

     cuts = load_manifest_lazy(manifest_dir)
     print(f"Loaded manifest from {manifest_dir}")

     def _add_context(c: Cut):
-        splits = remove_non_alphabetic(c.supervisions[0].texts[0], strict=False).upper().split()
+        splits = (
+            remove_non_alphabetic(c.supervisions[0].texts[0], strict=False)
+            .upper()
+            .split()
+        )
         found = []
         for w in splits:
             if w in rare_words:
@@ -147,7 +157,7 @@ def add_context_list_to_manifest(subset: str, min_count: int):
         return c

     cuts = cuts.map(_add_context)

     cuts.to_file(target_manifest_dir)
     print(f"Saved manifest with context list to {target_manifest_dir}")

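
A self-contained sketch of the per-cut tagging that _add_context applies: uppercase the supervision text and keep the words found in the rare-word set. The rare-word set and sentence here are hypothetical, and the real code additionally strips non-alphabetic characters via remove_non_alphabetic, which is skipped in this sketch:

    rare_words = {"QUIMPER", "GRONINGEN"}  # hypothetical rare-word set

    text = "he sailed from Quimper to Groningen"
    splits = text.upper().split()
    context_list = " ".join(w for w in splits if w in rare_words)
    print(context_list)  # -> QUIMPER GRONINGEN
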
@@ -161,7 +171,9 @@ def check(subset: str, min_count: int):
     has_context_list = [c.supervisions[0].context_list != "" for c in cuts]
     context_list_len = [len(c.supervisions[0].context_list.split()) for c in cuts]
     print(f"{sum(has_context_list)}/{total_cuts} cuts have context list! ")
-    print(f"Average length of non-empty context list is {sum(context_list_len)/sum(has_context_list)}")
+    print(
+        f"Average length of non-empty context list is {sum(context_list_len)/sum(has_context_list)}"
+    )


 def write_error_stats(
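
A toy version of the statistics printed by check(), using a hand-made list of context lists instead of a loaded manifest:

    context_lists = ["QUIMPER GRONINGEN", "", "HUSBANDRY"]  # made-up per-cut lists
    total_cuts = len(context_lists)
    has_context_list = [cl != "" for cl in context_lists]
    context_list_len = [len(cl.split()) for cl in context_lists]
    print(f"{sum(has_context_list)}/{total_cuts} cuts have context list!")  # 2/3
    print(
        f"Average length of non-empty context list is {sum(context_list_len)/sum(has_context_list)}"
    )  # (2 + 0 + 1) / 2 = 1.5
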
@@ -218,24 +230,24 @@ def write_error_stats(
     words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
     num_corr = 0
     ERR = "*"

     if compute_CER:
         for i, res in enumerate(results):
             cut_id, ref, hyp = res
             ref = list("".join(ref))
             hyp = list("".join(hyp))
             results[i] = (cut_id, ref, hyp)

     for cut_id, ref, hyp in results:
         ali = kaldialign.align(ref, hyp, ERR)
         for ref_word, hyp_word in ali:
             if ref_word == ERR:  # INSERTION
                 ins[hyp_word] += 1
                 words[hyp_word][3] += 1
             elif hyp_word == ERR:  # DELETION
                 dels[ref_word] += 1
                 words[ref_word][4] += 1
             elif hyp_word != ref_word:  # SUBSTITUTION
                 subs[(ref_word, hyp_word)] += 1
                 words[ref_word][1] += 1
                 words[hyp_word][2] += 1
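
For context, a minimal example of the alignment that kaldialign.align produces, which the loop above turns into insertion/deletion/substitution counts ("*" marks an empty slot on either side):

    import kaldialign

    ERR = "*"
    ref = ["THE", "SEA", "SHORE"]
    hyp = ["THE", "SEE"]
    ali = kaldialign.align(ref, hyp, ERR)
    print(ali)  # e.g. [('THE', 'THE'), ('SEA', 'SEE'), ('SHORE', '*')]
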
@@ -301,9 +313,7 @@ def write_error_stats(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word
-                    if ref_word == hyp_word
-                    else f"({ref_word}->{hyp_word})"
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -313,9 +323,7 @@ def write_error_stats(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)

-    for count, (ref, hyp) in sorted(
-        [(v, k) for k, v in subs.items()], reverse=True
-    ):
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
         print(f"{count} {ref} -> {hyp}", file=f)

     print("", file=f)
@@ -332,11 +340,9 @@ def write_error_stats(
     unbiased_word_errs = 0
     biased_word_counts = 0
     biased_word_errs = 0

     print("", file=f)
-    print(
-        "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
-    )
+    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)

     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
@@ -344,37 +350,36 @@ def write_error_stats(
         (corr, ref_sub, hyp_sub, ins, dels) = counts
         tot_errs = ref_sub + hyp_sub + ins + dels
         # number of appearances of "word" in reference text
-        ref_count = corr + ref_sub + dels # correct + in ref but got substituted + deleted
+        ref_count = (
+            corr + ref_sub + dels
+        )  # correct + in ref but got substituted + deleted
         # number of appearances of "word" in hyp text
         hyp_count = corr + hyp_sub + ins


         if biasing_words is not None:
             if word in biasing_words:
                 biased_word_counts += ref_count
-                biased_word_errs += (ins + dels + ref_sub)
+                biased_word_errs += ins + dels + ref_sub
             else:
                 unbiased_word_counts += ref_count
-                unbiased_word_errs += (ins + dels + hyp_sub)
+                unbiased_word_errs += ins + dels + hyp_sub

         print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)

     if biasing_words is not None:
-        B_WER = "%.2f" % (100 *biased_word_errs/biased_word_counts)
-        U_WER = "%.2f" % (100 *unbiased_word_errs/unbiased_word_counts)
+        B_WER = "%.2f" % (100 * biased_word_errs / biased_word_counts)
+        U_WER = "%.2f" % (100 * unbiased_word_errs / unbiased_word_counts)
         logging.info(f"Biased WER: {B_WER} [{biased_word_errs}/{biased_word_counts}] ")
-        logging.info(f"Un-biased WER: {U_WER} [{unbiased_word_errs}/{unbiased_word_counts}]")
+        logging.info(
+            f"Un-biased WER: {U_WER} [{unbiased_word_errs}/{unbiased_word_counts}]"
+        )

     return float(tot_err_rate)


-if __name__=="__main__":
-    #test_set = "test-clean"
-    #get_facebook_biasing_list(test_set)
+if __name__ == "__main__":
     subset = "medium"
-    min_count = 460
-    #get_rare_words(subset, min_count)
+    min_count = 10
+    get_rare_words(subset, min_count)
     add_context_list_to_manifest(subset=subset, min_count=min_count)
     check(subset=subset, min_count=min_count)
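
A worked example of the biased/unbiased WER split computed at the end of write_error_stats; the error and reference counts below are invented for illustration:

    biased_word_errs, biased_word_counts = 12, 200
    unbiased_word_errs, unbiased_word_counts = 144, 4800

    # same formulas as in the updated code above
    B_WER = "%.2f" % (100 * biased_word_errs / biased_word_counts)
    U_WER = "%.2f" % (100 * unbiased_word_errs / unbiased_word_counts)
    print(B_WER, U_WER)  # -> 6.00 3.00
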