From 754ac00509cdbac4a9a5d5615571c7d8dd132a4f Mon Sep 17 00:00:00 2001
From: marcoyang1998
Date: Thu, 20 Jul 2023 15:50:50 +0800
Subject: [PATCH] add more normalizations such as number/year to words; fix a
 few bugs when feeding input to WER computation

---
 egs/libriheavy/ASR/zipformer/decode.py             |  36 +++--
 .../ASR/zipformer/text_normalization.py            | 135 +++++++++++++++++-
 2 files changed, 156 insertions(+), 15 deletions(-)

diff --git a/egs/libriheavy/ASR/zipformer/decode.py b/egs/libriheavy/ASR/zipformer/decode.py
index 089a2f823..a4e28cf5d 100644
--- a/egs/libriheavy/ASR/zipformer/decode.py
+++ b/egs/libriheavy/ASR/zipformer/decode.py
@@ -118,7 +118,12 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
-from text_normalization import simple_normalization, upper_normalization
+from lhotse.cut import Cut
+from text_normalization import (
+    simple_normalization,
+    upper_normalization,
+    word_normalization,
+)
 from train import add_model_arguments, get_model, get_params
 
 from icefall.checkpoint import (
@@ -802,14 +807,29 @@ def main():
     args.return_cuts = True
     libriheavy = LibriHeavyAsrDataModule(args)
 
+    def add_texts(c: Cut):
+        text = c.supervisions[0].text
+        c.supervisions[0].texts = [text]
+        return c
+
     test_clean_cuts = libriheavy.test_clean_cuts()
     test_other_cuts = libriheavy.test_other_cuts()
+    ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
+    ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
+
+    ls_test_clean_cuts = ls_test_clean_cuts.map(add_texts)
+    ls_test_other_cuts = ls_test_other_cuts.map(add_texts)
 
     test_clean_dl = libriheavy.test_dataloaders(test_clean_cuts)
     test_other_dl = libriheavy.test_dataloaders(test_other_cuts)
+    ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
+    ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
 
-    test_sets = ["libriheavy-test-clean", "libriheavy-test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
+    # test_sets = ["libriheavy-test-clean", "libriheavy-test-other", "librispeech-test-clean", "librispeech-test-other"]
+    # test_dl = [test_clean_dl, test_other_dl, ls_test_clean_dl, ls_test_other_dl]
+
+    test_sets = ["librispeech-test-clean", "librispeech-test-other"]
+    test_dl = [ls_test_clean_dl, ls_test_other_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
@@ -834,12 +854,12 @@ def main():
             for k in results_dict:
                 new_ans = []
                 for item in results_dict[k]:
-                    id, hyp, ref = item
-                    hyp = [upper_normalization(w.upper()) for w in hyp]
+                    id, ref, hyp = item
+                    hyp = upper_normalization(" ".join(hyp)).split()
+                    hyp = [word_normalization(w) for w in hyp]
+                    hyp = " ".join(hyp).split()
                     hyp = [w for w in hyp if w != ""]
-                    ref = [upper_normalization(w.upper()) for w in ref]
-                    ref = [w for w in ref if w != ""]
-                    new_ans.append((id, hyp, ref))
+                    new_ans.append((id, ref, hyp))
                 new_res[k] = new_ans
 
             save_results(
diff --git a/egs/libriheavy/ASR/zipformer/text_normalization.py b/egs/libriheavy/ASR/zipformer/text_normalization.py
index bbde95be0..b27f0ecfc 100644
--- a/egs/libriheavy/ASR/zipformer/text_normalization.py
+++ b/egs/libriheavy/ASR/zipformer/text_normalization.py
@@ -1,5 +1,101 @@
 import re
 
+words = {
+    0: "zero",
+    1: "one",
+    2: "two",
+    3: "three",
+    4: "four",
+    5: "five",
+    6: "six",
+    7: "seven",
+    8: "eight",
+    9: "nine",
+    10: "ten",
+    11: "eleven",
+    12: "twelve",
+    13: "thirteen",
+    14: "fourteen",
+    15: "fifteen",
+    16: "sixteen",
+    17: "seventeen",
+    18: "eighteen",
+    19: "nineteen",
+    20: "twenty",
+    30: "thirty",
+    40: "forty",
+    50: "fifty",
"fifty", + 60: "sixty", + 70: "seventy", + 80: "eighty", + 90: "ninety", +} +ordinal_nums = [ + "zeroth", + "first", + "second", + "third", + "fourth", + "fifth", + "sixth", + "seventh", + "eighth", + "ninth", + "tenth", + "eleventh", + "twelfth", + "thirteenth", + "fourteenth", + "fifteenth", + "sixteenth", + "seventeenth", + "eighteenth", + "nineteenth", + "twentieth", +] + +num_ordinal_dict = {num: ordinal_nums[num] for num in range(21)} + + +def year_to_words(num: int): + assert isinstance(num, int), num + # check if a num is representing a year + if num > 1500 and num < 2000: + return words[num // 100] + " " + num_to_words(num % 100) + elif num == 2000: + return "TWO THOUSAND" + elif num > 2000: + return "TWO THOUSAND AND " + num_to_words(num % 100) + else: + return num_to_words(num) + + +def num_to_words(num: int): + # Return the English words of a integer number + + # If this is a year number + if num > 1500 and num < 2030: + return year_to_words(num) + + if num < 20: + return words[num] + if num < 100: + if num % 10 == 0: + return words[num // 10 * 10] + else: + return words[num // 10 * 10] + " " + words[num % 10] + if num < 1000: + return words[num // 100] + " hundred and " + num_to_words(num % 100) + if num < 1000000: + return num_to_words(num // 1000) + " thousand " + num_to_words(num % 1000) + return num + + +def num_to_ordinal_word(num: int): + + return num_ordinal_dict.get(num, num_to_words(num)).upper() + + def replace_full_width_symbol(s: str) -> str: # replace full-width symbol with theri half width counterpart s = s.replace("“", '"') @@ -10,18 +106,43 @@ def replace_full_width_symbol(s: str) -> str: return s -def upper_ref_text(text: str) -> str: +def upper_normalization(text: str) -> str: text = replace_full_width_symbol(text) - text = text.upper() # upper case all characters - + text = text.upper() # upper case all characters + # Only keep all alpha-numeric characters, hypen and apostrophe - text = text.replace("--", " ") - text = re.sub("[^a-zA-Z0-9\s\'-]+", "", text) + text = text.replace("-", " ") + text = re.sub("[^a-zA-Z0-9\s']+", "", text) return text + +def word_normalization(word: str) -> str: + if word == "MRS": + return "MISSUS" + if word == "MR": + return "MISTER" + if word == "ST": + return "SAINT" + if word == "ECT": + return "ET CETERA" + if word.isnumeric(): + word = num_to_words(int(word)) + return word.upper() + if word[-2:] == "TH" and word[0].isnumeric(): # e.g 9TH, 6TH + return num_to_ordinal_word(int(word[:-2])).upper() + + return word + + def simple_normalization(text: str) -> str: text = replace_full_width_symbol(text) text = text.replace("--", " ") - + return text - \ No newline at end of file + + +if __name__ == "__main__": + + s = str(1830) + out = word_normalization(s) + print(s, out)