mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
add wer and cer for Chinese and English respectively
This commit is contained in:
parent
c1334d4da6
commit
ba78791e9d
12
README.md
12
README.md
@ -293,12 +293,12 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
|
||||
|
||||
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
|
||||
|
||||
The best CER(%) results:
|
||||
| | dev | test |
|
||||
|----------------------|------|------|
|
||||
| greedy search | 7.30 | 7.39 |
|
||||
| fast beam search | 7.15 | 7.22 |
|
||||
| modified beam search | 7.18 | 7.26 |
|
||||
The best results for Chinese CER(%) and English WER(%) respectivly (zh: Chinese, en: English):
|
||||
|decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en |
|
||||
|--|--|--|--|--|--|--|
|
||||
|greedy_search| 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13|
|
||||
|modified_beam_search| 7.15 | 6.35 | 18.95 | 7.22| 6.50 | 18.70 |
|
||||
|fast_beam_search| 7.18 | 6.39| 18.90 | 7.27| 6.55 | 18.77|
|
||||
|
||||
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing)
|
||||
|
||||
|
@ -15,11 +15,18 @@ The WERs are
|
||||
|fast_beam_search | 30 | 24 | 7.32 | 7.42|
|
||||
|greedy_search(use-averaged-model=True) | 30 | 24 | 7.30 | 7.39|
|
||||
|modified_beam_search(use-averaged-model=True) | 30 | 24 | 7.15 | 7.22|
|
||||
|fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 7.26|
|
||||
|fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 7.27|
|
||||
|greedy_search | 348000 | 30 | 7.46 | 7.54|
|
||||
|modified_beam_search | 348000 | 30 | 7.24 | 7.36|
|
||||
|fast_beam_search | 348000 | 30 | 7.25 | 7.39 |
|
||||
|
||||
The results (CER(%) and WER(%)) for Chinese CER and English WER respectivly (zh: Chinese, en: English):
|
||||
|decoding-method | epoch(iter) | avg | dev | dev_zh | dev_en | test | test_zh | test_en |
|
||||
|--|--|--|--|--|--|--|--|--|
|
||||
|greedy_search(use-averaged-model=True) | 30 | 24 | 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13|
|
||||
|modified_beam_search(use-averaged-model=True) | 30 | 24 | 7.15 | 6.35 | 18.95 | 7.22| 6.50 | 18.70 |
|
||||
|fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 6.39| 18.90 | 7.27| 6.55 | 18.77|
|
||||
|
||||
The training command for reproducing is given below:
|
||||
|
||||
```
|
||||
|
@ -272,7 +272,11 @@ def decode_one_batch(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
hyps = []
|
||||
zh_hyps = []
|
||||
en_hyps = []
|
||||
pattern = re.compile(r"([\u4e00-\u9fff])")
|
||||
en_letter = "[\u0041-\u005a|\u0061-\u007a]+" # English letters
|
||||
zh_char = "[\u4e00-\u9fa5]+" # Chinese chars
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
@ -287,10 +291,18 @@ def decode_one_batch(
|
||||
hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
chars = pattern.split(hyp.upper())
|
||||
chars_new = []
|
||||
zh_text = []
|
||||
en_text = []
|
||||
for char in chars:
|
||||
if char != "":
|
||||
chars_new.extend(char.strip().split(" "))
|
||||
tokens = char.strip().split(" ")
|
||||
chars_new.extend(tokens)
|
||||
for token in tokens:
|
||||
zh_text.extend(re.findall(zh_char, token))
|
||||
en_text.extend(re.findall(en_letter, token))
|
||||
hyps.append(chars_new)
|
||||
zh_hyps.append(zh_text)
|
||||
en_hyps.append(en_text)
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
@ -304,10 +316,18 @@ def decode_one_batch(
|
||||
hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
chars = pattern.split(hyp.upper())
|
||||
chars_new = []
|
||||
zh_text = []
|
||||
en_text = []
|
||||
for char in chars:
|
||||
if char != "":
|
||||
chars_new.extend(char.strip().split(" "))
|
||||
tokens = char.strip().split(" ")
|
||||
chars_new.extend(tokens)
|
||||
for token in tokens:
|
||||
zh_text.extend(re.findall(zh_char, token))
|
||||
en_text.extend(re.findall(en_letter, token))
|
||||
hyps.append(chars_new)
|
||||
zh_hyps.append(zh_text)
|
||||
en_hyps.append(en_text)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
@ -319,10 +339,18 @@ def decode_one_batch(
|
||||
hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
chars = pattern.split(hyp.upper())
|
||||
chars_new = []
|
||||
zh_text = []
|
||||
en_text = []
|
||||
for char in chars:
|
||||
if char != "":
|
||||
chars_new.extend(char.strip().split(" "))
|
||||
tokens = char.strip().split(" ")
|
||||
chars_new.extend(tokens)
|
||||
for token in tokens:
|
||||
zh_text.extend(re.findall(zh_char, token))
|
||||
en_text.extend(re.findall(en_letter, token))
|
||||
hyps.append(chars_new)
|
||||
zh_hyps.append(zh_text)
|
||||
en_hyps.append(en_text)
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
@ -352,22 +380,30 @@ def decode_one_batch(
|
||||
)
|
||||
chars = pattern.split(hyp.upper())
|
||||
chars_new = []
|
||||
zh_text = []
|
||||
en_text = []
|
||||
for char in chars:
|
||||
if char != "":
|
||||
chars_new.extend(char.strip().split(" "))
|
||||
tokens = char.strip().split(" ")
|
||||
chars_new.extend(tokens)
|
||||
for token in tokens:
|
||||
zh_text.extend(re.findall(zh_char, token))
|
||||
en_text.extend(re.findall(en_letter, token))
|
||||
hyps.append(chars_new)
|
||||
zh_hyps.append(zh_text)
|
||||
en_hyps.append(en_text)
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
return {"greedy_search": (hyps, zh_hyps, en_hyps)}
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
return {
|
||||
(
|
||||
f"beam_{params.beam}_"
|
||||
f"max_contexts_{params.max_contexts}_"
|
||||
f"max_states_{params.max_states}"
|
||||
): hyps
|
||||
): (hyps, zh_hyps, en_hyps)
|
||||
}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
return {f"beam_size_{params.beam_size}": (hyps, zh_hyps, en_hyps)}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
@ -410,17 +446,30 @@ def decode_dataset(
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
zh_results = defaultdict(list)
|
||||
en_results = defaultdict(list)
|
||||
pattern = re.compile(r"([\u4e00-\u9fff])")
|
||||
en_letter = "[\u0041-\u005a|\u0061-\u007a]+" # English letters
|
||||
zh_char = "[\u4e00-\u9fa5]+" # Chinese chars
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
# texts = [list(str(text).replace(" ", "")) for text in texts]
|
||||
zh_texts = []
|
||||
en_texts = []
|
||||
for i in range(len(texts)):
|
||||
text = texts[i]
|
||||
chars = pattern.split(text.upper())
|
||||
chars_new = []
|
||||
zh_text = []
|
||||
en_text = []
|
||||
for char in chars:
|
||||
if char != "":
|
||||
chars_new.extend(char.strip().split(" "))
|
||||
tokens = char.strip().split(" ")
|
||||
chars_new.extend(tokens)
|
||||
for token in tokens:
|
||||
zh_text.extend(re.findall(zh_char, token))
|
||||
en_text.extend(re.findall(en_letter, token))
|
||||
zh_texts.append(zh_text)
|
||||
en_texts.append(en_text)
|
||||
texts[i] = chars_new
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -431,13 +480,25 @@ def decode_dataset(
|
||||
sp=sp,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
for name, hyps_texts in hyps_dict.items():
|
||||
this_batch = []
|
||||
this_batch_zh = []
|
||||
this_batch_en = []
|
||||
# print(hyps_texts)
|
||||
hyps, zh_hyps, en_hyps = hyps_texts
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
this_batch.append((ref_text, hyp_words))
|
||||
|
||||
for hyp_words, ref_text in zip(zh_hyps, zh_texts):
|
||||
this_batch_zh.append((ref_text, hyp_words))
|
||||
|
||||
for hyp_words, ref_text in zip(en_hyps, en_texts):
|
||||
this_batch_en.append((ref_text, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
zh_results[name + "_zh"].extend(this_batch_zh)
|
||||
en_results[name + "_en"].extend(this_batch_en)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
@ -447,7 +508,7 @@ def decode_dataset(
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
return results
|
||||
return results, zh_results, en_results
|
||||
|
||||
|
||||
def save_results(
|
||||
@ -663,7 +724,7 @@ def main():
|
||||
test_dl = [dev_dl, test_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
results_dict, zh_results_dict, en_results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
@ -676,6 +737,16 @@ def main():
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=zh_results_dict,
|
||||
)
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=en_results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user