From ba78791e9d5c407cb96eaf63b37145408efb10d9 Mon Sep 17 00:00:00 2001
From: luomingshuang <739314837@qq.com>
Date: Fri, 24 Jun 2022 16:33:44 +0800
Subject: [PATCH] Add WER and CER for Chinese and English respectively

---
 README.md                                     | 12 +--
 egs/tal_csasr/ASR/RESULTS.md                  |  9 +-
 .../pruned_transducer_stateless5/decode.py    | 95 ++++++++++++++++---
 3 files changed, 97 insertions(+), 19 deletions(-)

diff --git a/README.md b/README.md
index be00eac50..be922c191 100644
--- a/README.md
+++ b/README.md
@@ -293,12 +293,12 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
 
 #### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
 
-The best CER(%) results:
-|                      | dev  | test |
-|----------------------|------|------|
-| greedy search        | 7.30 | 7.39 |
-| fast beam search     | 7.15 | 7.22 |
-| modified beam search | 7.18 | 7.26 |
+The best results for Chinese CER(%) and English WER(%) respectively (zh: Chinese, en: English):
+|decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en |
+|--|--|--|--|--|--|--|
+|greedy_search | 7.30 | 6.48 | 19.19 | 7.39 | 6.66 | 19.13 |
+|modified_beam_search | 7.15 | 6.35 | 18.95 | 7.22 | 6.50 | 18.70 |
+|fast_beam_search | 7.18 | 6.39 | 18.90 | 7.27 | 6.55 | 18.77 |
 
 We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing)
 
diff --git a/egs/tal_csasr/ASR/RESULTS.md b/egs/tal_csasr/ASR/RESULTS.md
index b711fa82b..ddff0ab61 100644
--- a/egs/tal_csasr/ASR/RESULTS.md
+++ b/egs/tal_csasr/ASR/RESULTS.md
@@ -15,11 +15,18 @@ The WERs are
 |fast_beam_search | 30 | 24 | 7.32 | 7.42|
 |greedy_search(use-averaged-model=True) | 30 | 24 | 7.30 | 7.39|
 |modified_beam_search(use-averaged-model=True) | 30 | 24 | 7.15 | 7.22|
-|fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 7.26|
+|fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 7.27|
 |greedy_search | 348000 | 30 | 7.46 | 7.54|
 |modified_beam_search | 348000 | 30 | 7.24 | 7.36|
 |fast_beam_search | 348000 | 30 | 7.25 | 7.39 |
 
+The results for Chinese CER(%) and English WER(%) respectively (zh: Chinese, en: English):
+|decoding-method | epoch(iter) | avg | dev | dev_zh | dev_en | test | test_zh | test_en |
+|--|--|--|--|--|--|--|--|--|
+|greedy_search(use-averaged-model=True) | 30 | 24 | 7.30 | 6.48 | 19.19 | 7.39 | 6.66 | 19.13 |
+|modified_beam_search(use-averaged-model=True) | 30 | 24 | 7.15 | 6.35 | 18.95 | 7.22 | 6.50 | 18.70 |
+|fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 6.39 | 18.90 | 7.27 | 6.55 | 18.77 |
+
 The training command for reproducing is given below:
 
 ```
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
index 1ad6ed943..305729a99 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
@@ -272,7 +272,11 @@ def decode_one_batch(
         x=feature, x_lens=feature_lens
     )
     hyps = []
+    zh_hyps = []
+    en_hyps = []
     pattern = re.compile(r"([\u4e00-\u9fff])")
+    en_letter = "[\u0041-\u005a|\u0061-\u007a]+"  # English letters
+    zh_char = "[\u4e00-\u9fa5]+"  # Chinese chars
     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
             model=model,
@@ -287,10 +291,18 @@
             hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
             chars = pattern.split(hyp.upper())
             chars_new = []
+            zh_text = []
+            en_text = []
             for char in chars:
                 if char != "":
-                    chars_new.extend(char.strip().split(" "))
+                    tokens = char.strip().split(" ")
+                    chars_new.extend(tokens)
+                    for token in tokens:
+                        zh_text.extend(re.findall(zh_char, token))
+                        en_text.extend(re.findall(en_letter, token))
             hyps.append(chars_new)
+            zh_hyps.append(zh_text)
+            en_hyps.append(en_text)
     elif (
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
@@ -304,10 +316,18 @@
             hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
             chars = pattern.split(hyp.upper())
             chars_new = []
+            zh_text = []
+            en_text = []
             for char in chars:
                 if char != "":
-                    chars_new.extend(char.strip().split(" "))
+                    tokens = char.strip().split(" ")
+                    chars_new.extend(tokens)
+                    for token in tokens:
+                        zh_text.extend(re.findall(zh_char, token))
+                        en_text.extend(re.findall(en_letter, token))
             hyps.append(chars_new)
+            zh_hyps.append(zh_text)
+            en_hyps.append(en_text)
     elif params.decoding_method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,
@@ -319,10 +339,18 @@
             hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
             chars = pattern.split(hyp.upper())
             chars_new = []
+            zh_text = []
+            en_text = []
             for char in chars:
                 if char != "":
-                    chars_new.extend(char.strip().split(" "))
+                    tokens = char.strip().split(" ")
+                    chars_new.extend(tokens)
+                    for token in tokens:
+                        zh_text.extend(re.findall(zh_char, token))
+                        en_text.extend(re.findall(en_letter, token))
             hyps.append(chars_new)
+            zh_hyps.append(zh_text)
+            en_hyps.append(en_text)
     else:
         batch_size = encoder_out.size(0)
 
@@ -352,22 +380,30 @@
             )
             chars = pattern.split(hyp.upper())
             chars_new = []
+            zh_text = []
+            en_text = []
             for char in chars:
                 if char != "":
-                    chars_new.extend(char.strip().split(" "))
+                    tokens = char.strip().split(" ")
+                    chars_new.extend(tokens)
+                    for token in tokens:
+                        zh_text.extend(re.findall(zh_char, token))
+                        en_text.extend(re.findall(en_letter, token))
             hyps.append(chars_new)
+            zh_hyps.append(zh_text)
+            en_hyps.append(en_text)
     if params.decoding_method == "greedy_search":
-        return {"greedy_search": hyps}
+        return {"greedy_search": (hyps, zh_hyps, en_hyps)}
     elif params.decoding_method == "fast_beam_search":
         return {
             (
                 f"beam_{params.beam}_"
                 f"max_contexts_{params.max_contexts}_"
                 f"max_states_{params.max_states}"
-            ): hyps
+            ): (hyps, zh_hyps, en_hyps)
         }
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}": (hyps, zh_hyps, en_hyps)}
 
 
 def decode_dataset(
@@ -410,17 +446,30 @@
     log_interval = 20
 
     results = defaultdict(list)
+    zh_results = defaultdict(list)
+    en_results = defaultdict(list)
     pattern = re.compile(r"([\u4e00-\u9fff])")
+    en_letter = "[\u0041-\u005a|\u0061-\u007a]+"  # English letters
+    zh_char = "[\u4e00-\u9fa5]+"  # Chinese chars
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
-        # texts = [list(str(text).replace(" ", "")) for text in texts]
+        zh_texts = []
+        en_texts = []
         for i in range(len(texts)):
             text = texts[i]
             chars = pattern.split(text.upper())
             chars_new = []
+            zh_text = []
+            en_text = []
             for char in chars:
                 if char != "":
-                    chars_new.extend(char.strip().split(" "))
+                    tokens = char.strip().split(" ")
+                    chars_new.extend(tokens)
+                    for token in tokens:
+                        zh_text.extend(re.findall(zh_char, token))
+                        en_text.extend(re.findall(en_letter, token))
+            zh_texts.append(zh_text)
+            en_texts.append(en_text)
             texts[i] = chars_new
 
         hyps_dict = decode_one_batch(
             params=params,
@@ -431,13 +480,25 @@
             sp=sp,
         )
 
-        for name, hyps in hyps_dict.items():
+        for name, hyps_texts in hyps_dict.items():
             this_batch = []
+            this_batch_zh = []
+            this_batch_en = []
+            # print(hyps_texts)
+            hyps, zh_hyps, en_hyps = hyps_texts
             assert len(hyps) == len(texts)
             for hyp_words, ref_text in zip(hyps, texts):
                 this_batch.append((ref_text, hyp_words))
+            for hyp_words, ref_text in zip(zh_hyps, zh_texts):
+                this_batch_zh.append((ref_text, hyp_words))
+
+            for hyp_words, ref_text in zip(en_hyps, en_texts):
+                this_batch_en.append((ref_text, hyp_words))
+
             results[name].extend(this_batch)
+            zh_results[name + "_zh"].extend(this_batch_zh)
+            en_results[name + "_en"].extend(this_batch_en)
 
         num_cuts += len(texts)
 
@@ -447,7 +508,7 @@
             logging.info(
                 f"batch {batch_str}, cuts processed until now is {num_cuts}"
             )
-    return results
+    return results, zh_results, en_results
 
 
 def save_results(
@@ -663,7 +724,7 @@ def main():
     test_dl = [dev_dl, test_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
-        results_dict = decode_dataset(
+        results_dict, zh_results_dict, en_results_dict = decode_dataset(
             dl=test_dl,
             params=params,
             model=model,
@@ -676,6 +737,16 @@
             test_set_name=test_set,
             results_dict=results_dict,
         )
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=zh_results_dict,
+        )
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=en_results_dict,
+        )
 
     logging.info("Done!")
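For reviewers, here is a minimal standalone sketch (not part of the patch) of the token-splitting logic that `decode.py` now applies to both references and hypotheses: a code-switched string is separated into the mixed token list used for the overall error rate, a Chinese-character list scored as CER, and an English-word list scored as WER. The regular expressions are copied from the patch; the `split_mixed_text` helper name and the sample sentence are illustrative only. Note that the `|` inside the English character class is matched literally, which is harmless for typical ASR transcripts.

```python
import re

# Patterns copied from the patch (decode.py):
pattern = re.compile(r"([\u4e00-\u9fff])")  # split at each CJK character
en_letter = "[\u0041-\u005a|\u0061-\u007a]+"  # English letters (A-Z, a-z)
zh_char = "[\u4e00-\u9fa5]+"  # Chinese characters


def split_mixed_text(text: str):
    """Return (mixed_tokens, zh_tokens, en_tokens) for a code-switched string."""
    chars = pattern.split(text.upper())
    mixed_tokens, zh_tokens, en_tokens = [], [], []
    for char in chars:
        if char != "":
            tokens = char.strip().split(" ")
            mixed_tokens.extend(tokens)
            for token in tokens:
                zh_tokens.extend(re.findall(zh_char, token))
                en_tokens.extend(re.findall(en_letter, token))
    return mixed_tokens, zh_tokens, en_tokens


if __name__ == "__main__":
    mixed, zh, en = split_mixed_text("我们 today 学习 machine learning")
    print(mixed)  # ['我', '们', 'TODAY', '学', '习', 'MACHINE', 'LEARNING']
    print(zh)     # ['我', '们', '学', '习'] -> Chinese characters, scored as CER
    print(en)     # ['TODAY', 'MACHINE', 'LEARNING'] -> English words, scored as WER
```

In the patch itself, the three lists produced this way for each utterance are paired as (reference, hypothesis) tuples and passed through the recipe's existing `save_results` machinery three times, which is what yields the mixed, `_zh`, and `_en` numbers reported in the tables above.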