add more normalizations such as number/year to words; fix a few bugs when feeding input to WER computation

This commit is contained in:
marcoyang1998 2023-07-20 15:50:50 +08:00
parent 5532bb1683
commit 754ac00509
2 changed files with 156 additions and 15 deletions

View File

@ -118,7 +118,12 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
)
from text_normalization import simple_normalization, upper_normalization
from lhotse.cut import Cut
from text_normalization import (
simple_normalization,
upper_normalization,
word_normalization,
)
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
@ -802,14 +807,29 @@ def main():
args.return_cuts = True
libriheavy = LibriHeavyAsrDataModule(args)
def add_texts(c: Cut):
text = c.supervisions[0].text
c.supervisions[0].texts = [text]
return c
test_clean_cuts = libriheavy.test_clean_cuts()
test_other_cuts = libriheavy.test_other_cuts()
ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
ls_test_clean_cuts = ls_test_clean_cuts.map(add_texts)
ls_test_other_cuts = ls_test_other_cuts.map(add_texts)
test_clean_dl = libriheavy.test_dataloaders(test_clean_cuts)
test_other_dl = libriheavy.test_dataloaders(test_other_cuts)
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
test_sets = ["libriheavy-test-clean", "libriheavy-test-other"]
test_dl = [test_clean_dl, test_other_dl]
# test_sets = ["libriheavy-test-clean", "libriheavy-test-other", "librispeech-test-clean", "librispeech-test-other"]
# test_dl = [test_clean_dl, test_other_dl, ls_test_clean_dl, ls_test_other_dl]
test_sets = ["librispeech-test-clean", "librispeech-test-other"]
test_dl = [ls_test_clean_dl, ls_test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
@ -834,12 +854,12 @@ def main():
for k in results_dict:
new_ans = []
for item in results_dict[k]:
id, hyp, ref = item
hyp = [upper_normalization(w.upper()) for w in hyp]
id, ref, hyp = item
hyp = upper_normalization(" ".join(hyp)).split()
hyp = [word_normalization(w) for w in hyp]
hyp = " ".join(hyp).split()
hyp = [w for w in hyp if w != ""]
ref = [upper_normalization(w.upper()) for w in ref]
ref = [w for w in ref if w != ""]
new_ans.append((id, hyp, ref))
new_ans.append((id, ref, hyp))
new_res[k] = new_ans
save_results(

View File

@ -1,5 +1,101 @@
import re
words = {
0: "zero",
1: "one",
2: "two",
3: "three",
4: "four",
5: "five",
6: "six",
7: "seven",
8: "eight",
9: "nine",
10: "ten",
11: "eleven",
12: "twelve",
13: "thirteen",
14: "fourteen",
15: "fifteen",
16: "sixteen",
17: "seventeen",
18: "eighteen",
19: "nineteen",
20: "twenty",
30: "thirty",
40: "forty",
50: "fifty",
60: "sixty",
70: "seventy",
80: "eighty",
90: "ninety",
}
ordinal_nums = [
"zeroth",
"first",
"second",
"third",
"fourth",
"fifth",
"sixth",
"seventh",
"eighth",
"ninth",
"tenth",
"eleventh",
"twelfth",
"thirteenth",
"fourteenth",
"fifteenth",
"sixteenth",
"seventeenth",
"eighteenth",
"nineteenth",
"twentieth",
]
num_ordinal_dict = {num: ordinal_nums[num] for num in range(21)}
def year_to_words(num: int):
assert isinstance(num, int), num
# check if a num is representing a year
if num > 1500 and num < 2000:
return words[num // 100] + " " + num_to_words(num % 100)
elif num == 2000:
return "TWO THOUSAND"
elif num > 2000:
return "TWO THOUSAND AND " + num_to_words(num % 100)
else:
return num_to_words(num)
def num_to_words(num: int):
# Return the English words of a integer number
# If this is a year number
if num > 1500 and num < 2030:
return year_to_words(num)
if num < 20:
return words[num]
if num < 100:
if num % 10 == 0:
return words[num // 10 * 10]
else:
return words[num // 10 * 10] + " " + words[num % 10]
if num < 1000:
return words[num // 100] + " hundred and " + num_to_words(num % 100)
if num < 1000000:
return num_to_words(num // 1000) + " thousand " + num_to_words(num % 1000)
return num
def num_to_ordinal_word(num: int):
return num_ordinal_dict.get(num, num_to_words(num)).upper()
def replace_full_width_symbol(s: str) -> str:
# replace full-width symbol with theri half width counterpart
s = s.replace("", '"')
@ -10,18 +106,43 @@ def replace_full_width_symbol(s: str) -> str:
return s
def upper_ref_text(text: str) -> str:
def upper_normalization(text: str) -> str:
text = replace_full_width_symbol(text)
text = text.upper() # upper case all characters
text = text.upper() # upper case all characters
# Only keep all alpha-numeric characters, hypen and apostrophe
text = text.replace("--", " ")
text = re.sub("[^a-zA-Z0-9\s\'-]+", "", text)
text = text.replace("-", " ")
text = re.sub("[^a-zA-Z0-9\s']+", "", text)
return text
def word_normalization(word: str) -> str:
if word == "MRS":
return "MISSUS"
if word == "MR":
return "MISTER"
if word == "ST":
return "SAINT"
if word == "ECT":
return "ET CETERA"
if word.isnumeric():
word = num_to_words(int(word))
return word.upper()
if word[-2:] == "TH" and word[0].isnumeric(): # e.g 9TH, 6TH
return num_to_ordinal_word(int(word[:-2])).upper()
return word
def simple_normalization(text: str) -> str:
text = replace_full_width_symbol(text)
text = text.replace("--", " ")
return text
if __name__ == "__main__":
s = str(1830)
out = word_normalization(s)
print(s, out)