add wer and cer for Chinese and English respectively

This commit is contained in:
luomingshuang 2022-06-24 16:33:44 +08:00
parent c1334d4da6
commit ba78791e9d
3 changed files with 97 additions and 19 deletions

View File

@ -293,12 +293,12 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss #### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
The best CER(%) results: The best results for Chinese CER(%) and English WER(%) respectivly (zh: Chinese, en: English):
| | dev | test | |decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en |
|----------------------|------|------| |--|--|--|--|--|--|--|
| greedy search | 7.30 | 7.39 | |greedy_search| 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13|
| fast beam search | 7.15 | 7.22 | |modified_beam_search| 7.15 | 6.35 | 18.95 | 7.22| 6.50 | 18.70 |
| modified beam search | 7.18 | 7.26 | |fast_beam_search| 7.18 | 6.39| 18.90 | 7.27| 6.55 | 18.77|
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing) We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing)

View File

@ -15,11 +15,18 @@ The WERs are
|fast_beam_search | 30 | 24 | 7.32 | 7.42| |fast_beam_search | 30 | 24 | 7.32 | 7.42|
|greedy_search(use-averaged-model=True) | 30 | 24 | 7.30 | 7.39| |greedy_search(use-averaged-model=True) | 30 | 24 | 7.30 | 7.39|
|modified_beam_search(use-averaged-model=True) | 30 | 24 | 7.15 | 7.22| |modified_beam_search(use-averaged-model=True) | 30 | 24 | 7.15 | 7.22|
|fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 7.26| |fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 7.27|
|greedy_search | 348000 | 30 | 7.46 | 7.54| |greedy_search | 348000 | 30 | 7.46 | 7.54|
|modified_beam_search | 348000 | 30 | 7.24 | 7.36| |modified_beam_search | 348000 | 30 | 7.24 | 7.36|
|fast_beam_search | 348000 | 30 | 7.25 | 7.39 | |fast_beam_search | 348000 | 30 | 7.25 | 7.39 |
The results (CER(%) and WER(%)) for Chinese CER and English WER respectivly (zh: Chinese, en: English):
|decoding-method | epoch(iter) | avg | dev | dev_zh | dev_en | test | test_zh | test_en |
|--|--|--|--|--|--|--|--|--|
|greedy_search(use-averaged-model=True) | 30 | 24 | 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13|
|modified_beam_search(use-averaged-model=True) | 30 | 24 | 7.15 | 6.35 | 18.95 | 7.22| 6.50 | 18.70 |
|fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 6.39| 18.90 | 7.27| 6.55 | 18.77|
The training command for reproducing is given below: The training command for reproducing is given below:
``` ```

View File

@ -272,7 +272,11 @@ def decode_one_batch(
x=feature, x_lens=feature_lens x=feature, x_lens=feature_lens
) )
hyps = [] hyps = []
zh_hyps = []
en_hyps = []
pattern = re.compile(r"([\u4e00-\u9fff])") pattern = re.compile(r"([\u4e00-\u9fff])")
en_letter = "[\u0041-\u005a|\u0061-\u007a]+" # English letters
zh_char = "[\u4e00-\u9fa5]+" # Chinese chars
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
model=model, model=model,
@ -287,10 +291,18 @@ def decode_one_batch(
hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
chars = pattern.split(hyp.upper()) chars = pattern.split(hyp.upper())
chars_new = [] chars_new = []
zh_text = []
en_text = []
for char in chars: for char in chars:
if char != "": if char != "":
chars_new.extend(char.strip().split(" ")) tokens = char.strip().split(" ")
chars_new.extend(tokens)
for token in tokens:
zh_text.extend(re.findall(zh_char, token))
en_text.extend(re.findall(en_letter, token))
hyps.append(chars_new) hyps.append(chars_new)
zh_hyps.append(zh_text)
en_hyps.append(en_text)
elif ( elif (
params.decoding_method == "greedy_search" params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1 and params.max_sym_per_frame == 1
@ -304,10 +316,18 @@ def decode_one_batch(
hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
chars = pattern.split(hyp.upper()) chars = pattern.split(hyp.upper())
chars_new = [] chars_new = []
zh_text = []
en_text = []
for char in chars: for char in chars:
if char != "": if char != "":
chars_new.extend(char.strip().split(" ")) tokens = char.strip().split(" ")
chars_new.extend(tokens)
for token in tokens:
zh_text.extend(re.findall(zh_char, token))
en_text.extend(re.findall(en_letter, token))
hyps.append(chars_new) hyps.append(chars_new)
zh_hyps.append(zh_text)
en_hyps.append(en_text)
elif params.decoding_method == "modified_beam_search": elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
@ -319,10 +339,18 @@ def decode_one_batch(
hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
chars = pattern.split(hyp.upper()) chars = pattern.split(hyp.upper())
chars_new = [] chars_new = []
zh_text = []
en_text = []
for char in chars: for char in chars:
if char != "": if char != "":
chars_new.extend(char.strip().split(" ")) tokens = char.strip().split(" ")
chars_new.extend(tokens)
for token in tokens:
zh_text.extend(re.findall(zh_char, token))
en_text.extend(re.findall(en_letter, token))
hyps.append(chars_new) hyps.append(chars_new)
zh_hyps.append(zh_text)
en_hyps.append(en_text)
else: else:
batch_size = encoder_out.size(0) batch_size = encoder_out.size(0)
@ -352,22 +380,30 @@ def decode_one_batch(
) )
chars = pattern.split(hyp.upper()) chars = pattern.split(hyp.upper())
chars_new = [] chars_new = []
zh_text = []
en_text = []
for char in chars: for char in chars:
if char != "": if char != "":
chars_new.extend(char.strip().split(" ")) tokens = char.strip().split(" ")
chars_new.extend(tokens)
for token in tokens:
zh_text.extend(re.findall(zh_char, token))
en_text.extend(re.findall(en_letter, token))
hyps.append(chars_new) hyps.append(chars_new)
zh_hyps.append(zh_text)
en_hyps.append(en_text)
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
return {"greedy_search": hyps} return {"greedy_search": (hyps, zh_hyps, en_hyps)}
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
return { return {
( (
f"beam_{params.beam}_" f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_" f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}" f"max_states_{params.max_states}"
): hyps ): (hyps, zh_hyps, en_hyps)
} }
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": (hyps, zh_hyps, en_hyps)}
def decode_dataset( def decode_dataset(
@ -410,17 +446,30 @@ def decode_dataset(
log_interval = 20 log_interval = 20
results = defaultdict(list) results = defaultdict(list)
zh_results = defaultdict(list)
en_results = defaultdict(list)
pattern = re.compile(r"([\u4e00-\u9fff])") pattern = re.compile(r"([\u4e00-\u9fff])")
en_letter = "[\u0041-\u005a|\u0061-\u007a]+" # English letters
zh_char = "[\u4e00-\u9fa5]+" # Chinese chars
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
# texts = [list(str(text).replace(" ", "")) for text in texts] zh_texts = []
en_texts = []
for i in range(len(texts)): for i in range(len(texts)):
text = texts[i] text = texts[i]
chars = pattern.split(text.upper()) chars = pattern.split(text.upper())
chars_new = [] chars_new = []
zh_text = []
en_text = []
for char in chars: for char in chars:
if char != "": if char != "":
chars_new.extend(char.strip().split(" ")) tokens = char.strip().split(" ")
chars_new.extend(tokens)
for token in tokens:
zh_text.extend(re.findall(zh_char, token))
en_text.extend(re.findall(en_letter, token))
zh_texts.append(zh_text)
en_texts.append(en_text)
texts[i] = chars_new texts[i] = chars_new
hyps_dict = decode_one_batch( hyps_dict = decode_one_batch(
params=params, params=params,
@ -431,13 +480,25 @@ def decode_dataset(
sp=sp, sp=sp,
) )
for name, hyps in hyps_dict.items(): for name, hyps_texts in hyps_dict.items():
this_batch = [] this_batch = []
this_batch_zh = []
this_batch_en = []
# print(hyps_texts)
hyps, zh_hyps, en_hyps = hyps_texts
assert len(hyps) == len(texts) assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts): for hyp_words, ref_text in zip(hyps, texts):
this_batch.append((ref_text, hyp_words)) this_batch.append((ref_text, hyp_words))
for hyp_words, ref_text in zip(zh_hyps, zh_texts):
this_batch_zh.append((ref_text, hyp_words))
for hyp_words, ref_text in zip(en_hyps, en_texts):
this_batch_en.append((ref_text, hyp_words))
results[name].extend(this_batch) results[name].extend(this_batch)
zh_results[name + "_zh"].extend(this_batch_zh)
en_results[name + "_en"].extend(this_batch_en)
num_cuts += len(texts) num_cuts += len(texts)
@ -447,7 +508,7 @@ def decode_dataset(
logging.info( logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}" f"batch {batch_str}, cuts processed until now is {num_cuts}"
) )
return results return results, zh_results, en_results
def save_results( def save_results(
@ -663,7 +724,7 @@ def main():
test_dl = [dev_dl, test_dl] test_dl = [dev_dl, test_dl]
for test_set, test_dl in zip(test_sets, test_dl): for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset( results_dict, zh_results_dict, en_results_dict = decode_dataset(
dl=test_dl, dl=test_dl,
params=params, params=params,
model=model, model=model,
@ -676,6 +737,16 @@ def main():
test_set_name=test_set, test_set_name=test_set,
results_dict=results_dict, results_dict=results_dict,
) )
save_results(
params=params,
test_set_name=test_set,
results_dict=zh_results_dict,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=en_results_dict,
)
logging.info("Done!") logging.info("Done!")