From e9ccc0b0738bafa5acbc58b9172948bd57af5a15 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Fri, 22 Sep 2023 15:33:29 +0800
Subject: [PATCH] Fix decode.py

---
 .../local/{prepare_text.py => norm_text.py}   |  18 +-
 egs/libriheavy/ASR/local/prepare_manifest.py  |   7 +-
 egs/libriheavy/ASR/prepare.sh                 |   5 +-
 egs/libriheavy/ASR/zipformer/decode.py        | 133 ++++--------
 .../ASR/zipformer/text_normalization.py       |  52 +++++++
 egs/libriheavy/ASR/zipformer/train.py         |  20 ++-
 requirements-ci.txt                           |   1 +
 requirements.txt                              |   1 +
 8 files changed, 99 insertions(+), 138 deletions(-)
 rename egs/libriheavy/ASR/local/{prepare_text.py => norm_text.py} (74%)
 create mode 100644 egs/libriheavy/ASR/zipformer/text_normalization.py

diff --git a/egs/libriheavy/ASR/local/prepare_text.py b/egs/libriheavy/ASR/local/norm_text.py
similarity index 74%
rename from egs/libriheavy/ASR/local/prepare_text.py
rename to egs/libriheavy/ASR/local/norm_text.py
index 9cf59bc7f..99f59a320 100755
--- a/egs/libriheavy/ASR/local/prepare_text.py
+++ b/egs/libriheavy/ASR/local/norm_text.py
@@ -27,22 +27,9 @@ def get_args():
         help="""Path to the input text.
         """,
     )
-    parser.add_argument(
-        "--normalize",
-        action='store_true',
-        help="""Whether to normalize the text.
-        True to normalize the text to upper and remove all punctuation.
-        """
-    )
     return parser.parse_args()


-def simple_cleanup(text: str) -> str:
-    table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
-    text = text.translate(table)
-    return text.strip()
-
-
 def remove_punc_to_upper(text: str) -> str:
     text = text.replace("‘", "'")
     text = text.replace("’", "'")
@@ -62,10 +49,7 @@ def main():
     sys.stdout = codecs.getwriter("utf-8")(sys.stdout.buffer)
     line = f.readline()
     while line:
-        if args.normalize:
-            print(remove_punc_to_upper(line))
-        else:
-            print(simple_cleanup(line))
+        print(remove_punc_to_upper(line))
         line = f.readline()

diff --git a/egs/libriheavy/ASR/local/prepare_manifest.py b/egs/libriheavy/ASR/local/prepare_manifest.py
index 15ff34ffa..720455e20 100755
--- a/egs/libriheavy/ASR/local/prepare_manifest.py
+++ b/egs/libriheavy/ASR/local/prepare_manifest.py
@@ -20,6 +20,11 @@ import json
 import sys
 from pathlib import Path

+def simple_cleanup(text: str) -> str:
+    table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
+    text = text.translate(table)
+    return text.strip()
+
 # Assign text of the supervisions and remove unnecessary entries.
 def main():
     assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR"
@@ -28,7 +33,7 @@ def main():
     with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout:
         for line in fin:
             cut = json.loads(line)
-            cut["supervisions"][0]["text"] = cut["supervisions"][0]["custom"]["texts"][0]
+            cut["supervisions"][0]["text"] = simple_cleanup(cut["supervisions"][0]["custom"]["texts"][0])
             del cut["supervisions"][0]["custom"]
             del cut["custom"]
             fout.write((json.dumps(cut) + "\n").encode())
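With this split, the two cleanup flavours run at different points of the recipe: `simple_cleanup` (punctuation kept, full-width characters folded to ASCII) is applied once when manifests are built, while `norm_text.py` only emits the fully normalized form used for BPE training. A standalone sketch of the difference, with both function bodies copied verbatim from the hunks above and a made-up input string:

```python
# Copied from local/prepare_manifest.py and local/norm_text.py above,
# so this snippet runs on its own.
def simple_cleanup(text: str) -> str:
    table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
    text = text.translate(table)
    return text.strip()


def remove_punc_to_upper(text: str) -> str:
    text = text.replace("‘", "'")
    text = text.replace("’", "'")
    tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
    s_list = [x.upper() if x in tokens else " " for x in text]
    return " ".join("".join(s_list).split()).strip()


text = "I’m fine, “she” said!"
assert simple_cleanup(text) == "I'm fine, \"she\" said!"  # punctuation survives
assert remove_punc_to_upper(text) == "I'M FINE SHE SAID"  # punctuation removed
```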
diff --git a/egs/libriheavy/ASR/prepare.sh b/egs/libriheavy/ASR/prepare.sh
index dd3815b95..e16d7295d 100755
--- a/egs/libriheavy/ASR/prepare.sh
+++ b/egs/libriheavy/ASR/prepare.sh
@@ -221,7 +221,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
   if [ ! -f data/texts ]; then
     gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
       | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
-      | ./local/prepare_text.py --normalize > data/texts
+      | ./local/norm_text.py > data/texts
   fi

   for vocab_size in ${vocab_sizes[@]}; do
@@ -244,8 +244,7 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
   log "Stage 10: Train BPE model for unnormalized text"
   if [ ! -f data/punc_texts ]; then
     gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
-      | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
-      | ./local/prepare_text.py > data/punc_texts
+      | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts
   fi
   for vocab_size in ${vocab_sizes[@]}; do
     new_vocab_size=$(($vocab_size + 256))
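For readers who find the jq/sed chain opaque, a rough Python equivalent of the Stage 9 extraction is sketched below. It assumes it is launched from `egs/libriheavy/ASR` with `local/` added to `sys.path`, and that `norm_text.py` guards its CLI entry point behind `if __name__ == "__main__"`; the paths are illustrative stand-ins for `$manifests_dir`.

```python
import gzip
import json
import sys

sys.path.insert(0, "local")
from norm_text import remove_punc_to_upper  # defined in local/norm_text.py above

# Equivalent of: gunzip -c ... | jq '.supervisions[].text' | ./local/norm_text.py
with gzip.open("data/manifests/libriheavy_cuts_medium.jsonl.gz", "rt") as fin, open(
    "data/texts", "w", encoding="utf-8"
) as fout:
    for line in fin:
        cut = json.loads(line)
        for sup in cut["supervisions"]:
            fout.write(remove_punc_to_upper(sup["text"]) + "\n")
```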
- """, - ) - parser.add_argument( "--max-contexts", type=int, @@ -310,6 +292,14 @@ def get_parser(): fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) + parser.add_argument( + "--train-with-punctuation", + type=str2bool, + default=False, + help="""Set to True, if the model was trained on texts with casing + and punctuation.""" + ) + parser.add_argument( "--post-normalization", type=str2bool, @@ -492,8 +482,6 @@ def decode_one_batch( if "nbest" in params.decoding_method: key += f"_num_paths_{params.num_paths}_" key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} else: @@ -573,6 +561,16 @@ def decode_dataset( results[name].extend(this_batch) + this_batch = [] + if params.post_normalization and params.train_with_punctuation: + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = remove_punc_to_upper(ref_text).split() + hyp_words = remove_punc_to_upper(" ".join(hyp_words)).split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[f"{name}_norm"].extend(this_batch) + + num_cuts += len(texts) if batch_idx % log_interval == 0: @@ -584,17 +582,6 @@ def decode_dataset( return results -def post_processing( - results: List[Tuple[str, List[str], List[str]]], -) -> List[Tuple[str, List[str], List[str]]]: - new_results = [] - for key, ref, hyp in results: - new_ref = asr_text_post_processing(" ".join(ref)).split() - new_hyp = asr_text_post_processing(" ".join(hyp)).split() - new_results.append((key, new_ref, new_hyp)) - return new_results - - def save_results( params: AttributeDict, test_set_name: str, @@ -605,8 +592,6 @@ def save_results( recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) - if test_set_name == "giga-dev" or test_set_name == "giga-test": - results = post_processing(results) results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -656,7 +641,6 @@ def main(): "beam_search", "fast_beam_search", "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", ) @@ -684,8 +668,6 @@ def main(): if "nbest" in params.decoding_method: params.suffix += f"-nbest-scale-{params.nbest_scale}" params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: @@ -798,21 +780,9 @@ def main(): model.eval() if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None - word_table = None num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -821,37 +791,23 @@ def main(): args.return_cuts = True libriheavy = LibriHeavyAsrDataModule(args) - def add_texts(c: Cut): - text = 
diff --git a/egs/libriheavy/ASR/zipformer/text_normalization.py b/egs/libriheavy/ASR/zipformer/text_normalization.py
new file mode 100644
index 000000000..9c9ee6725
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/text_normalization.py
@@ -0,0 +1,52 @@
+from num2words import num2words
+
+
+def remove_punc_to_upper(text: str) -> str:
+    text = text.replace("‘", "'")
+    text = text.replace("’", "'")
+    tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
+    s_list = [x.upper() if x in tokens else " " for x in text]
+    s = " ".join("".join(s_list).split()).strip()
+    return s
+
+
+def word_normalization(word: str) -> str:
+    # 1. Use the full word for some abbreviations.
+    # 2. Convert digits to English words.
+    # 3. Convert ordinal numbers to English words.
+    if word == "MRS":
+        return "MISSUS"
+    if word == "MR":
+        return "MISTER"
+    if word == "ST":
+        return "SAINT"
+    if word == "ECT":
+        return "ET CETERA"
+
+    if word[-2:] in ("ST", "ND", "RD", "TH") and word[:-2].isnumeric():  # e.g. 9TH, 6TH
+        word = num2words(word[:-2], to="ordinal")
+        word = word.replace("-", " ")
+
+    if word.isnumeric():
+        num = int(word)
+        if num > 1500 and num < 2030:
+            word = num2words(word, to="year")
+        else:
+            word = num2words(word)
+        word = word.replace("-", " ")
+
+    return word.upper()
+
+
+def text_normalization(text: str) -> str:
+    text = text.upper()
+    return " ".join([word_normalization(x) for x in text.split()])
+
+
+if __name__ == "__main__":
+    assert remove_punc_to_upper("I like this 《book>") == "I LIKE THIS BOOK"
+    assert (
+        text_normalization("Hello Mrs st 21st world 3rd she 99th MR")
+        == "HELLO MISSUS SAINT TWENTY FIRST WORLD THIRD SHE NINETY NINTH MISTER"
+    )
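A few expected outputs of the helpers above, runnable from the `zipformer` directory. The `1999` case assumes `num2words` renders years the way recent releases do ("nineteen ninety-nine" before the hyphen replacement):

```python
# Spot checks for text_normalization.py; run from egs/libriheavy/ASR/zipformer.
from text_normalization import text_normalization, word_normalization

assert word_normalization("MRS") == "MISSUS"          # abbreviation table
assert word_normalization("21ST") == "TWENTY FIRST"   # ordinal branch
assert word_normalization("42") == "FORTY TWO"        # cardinal branch
assert word_normalization("1999") == "NINETEEN NINETY NINE"  # year branch

assert (
    text_normalization("He was born in 1999")
    == "HE WAS BORN IN NINETEEN NINETY NINE"
)
```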
diff --git a/egs/libriheavy/ASR/zipformer/train.py b/egs/libriheavy/ASR/zipformer/train.py
index 4c6ee976f..b60d7d43b 100644
--- a/egs/libriheavy/ASR/zipformer/train.py
+++ b/egs/libriheavy/ASR/zipformer/train.py
@@ -77,6 +77,7 @@ from model import AsrModel
 from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
+from text_normalization import remove_punc_to_upper
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -1188,16 +1189,7 @@ def run(rank, world_size, args):

     def normalize_text(c: Cut):
-        text = c.supervisions[0].text
-        if params.train_with_punctuation:
-            table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
-            text = text.translate(table)
-        else:
-            text = text.replace("‘", "'")
-            text = text.replace("’", "'")
-            tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
-            text_list = [x.upper() if x in tokens else " " for x in text]
-            text = " ".join("".join(text_list).split()).strip()
+        text = remove_punc_to_upper(c.supervisions[0].text)
         c.supervisions[0].text = text
         return c

@@ -1245,7 +1237,9 @@ def run(rank, world_size, args):
         train_cuts += libriheavy.train_medium_cuts()
     if params.subset == "L":
         train_cuts += libriheavy.train_large_cuts()
-    train_cuts = train_cuts.map(normalize_text)
+
+    if not params.train_with_punctuation:
+        train_cuts = train_cuts.map(normalize_text)

     train_cuts = train_cuts.filter(remove_short_and_long_utt)

@@ -1261,7 +1255,9 @@ def run(rank, world_size, args):
     )

     valid_cuts = libriheavy.dev_cuts()
-    valid_cuts = valid_cuts.map(normalize_text)
+
+    if not params.train_with_punctuation:
+        valid_cuts = valid_cuts.map(normalize_text)

     valid_dl = libriheavy.valid_dataloaders(valid_cuts)

diff --git a/requirements-ci.txt b/requirements-ci.txt
index 2433e190b..ba8f1387f 100644
--- a/requirements-ci.txt
+++ b/requirements-ci.txt
@@ -17,6 +17,7 @@ six
 git+https://github.com/lhotse-speech/lhotse
 kaldilm==1.11
 kaldialign==0.7.1
+num2words
 sentencepiece==0.1.96
 tensorboard==2.8.0
 typeguard==2.13.3
diff --git a/requirements.txt b/requirements.txt
index a07f6b7c7..9ce2e0735 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 kaldifst
 kaldilm
 kaldialign
+num2words
 sentencepiece>=0.1.96
 tensorboard
 typeguard
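`num2words` enters both requirements files because `word_normalization` above depends on it. The three call forms it uses are shown below, with outputs as produced by recent `num2words` releases (worth re-checking if the dependency is ever pinned to a different version):

```python
from num2words import num2words

assert num2words(42) == "forty-two"                          # cardinal
assert num2words("9", to="ordinal") == "ninth"               # ordinal, str input
assert num2words(1984, to="year") == "nineteen eighty-four"  # year reading
```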