mirror of https://github.com/k2-fsa/icefall.git
add cer and wer for Chinese and English respectively
parent c1334d4da6
commit ba78791e9d
@@ -293,12 +293,12 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder

 #### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss

-The best CER(%) results:
+The best results for Chinese CER(%) and English WER(%) respectively (zh: Chinese, en: English):

-|                      | dev  | test |
-|----------------------|------|------|
-| greedy search        | 7.30 | 7.39 |
-| fast beam search     | 7.15 | 7.22 |
-| modified beam search | 7.18 | 7.26 |
+| decoding-method       | dev  | dev_zh | dev_en | test | test_zh | test_en |
+|--|--|--|--|--|--|--|
+| greedy_search         | 7.30 | 6.48   | 19.19  | 7.39 | 6.66    | 19.13   |
+| modified_beam_search  | 7.15 | 6.35   | 18.95  | 7.22 | 6.50    | 18.70   |
+| fast_beam_search      | 7.18 | 6.39   | 18.90  | 7.27 | 6.55    | 18.77   |

 We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing)
@@ -15,11 +15,18 @@ The WERs are

 |fast_beam_search | 30 | 24 | 7.32 | 7.42|
 |greedy_search(use-averaged-model=True) | 30 | 24 | 7.30 | 7.39|
 |modified_beam_search(use-averaged-model=True) | 30 | 24 | 7.15 | 7.22|
-|fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 7.26|
+|fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 7.27|
 |greedy_search | 348000 | 30 | 7.46 | 7.54|
 |modified_beam_search | 348000 | 30 | 7.24 | 7.36|
 |fast_beam_search | 348000 | 30 | 7.25 | 7.39 |

+The results for Chinese CER(%) and English WER(%) respectively (zh: Chinese, en: English):
+|decoding-method | epoch(iter) | avg | dev | dev_zh | dev_en | test | test_zh | test_en |
+|--|--|--|--|--|--|--|--|--|
+|greedy_search(use-averaged-model=True) | 30 | 24 | 7.30 | 6.48 | 19.19 | 7.39 | 6.66 | 19.13 |
+|modified_beam_search(use-averaged-model=True) | 30 | 24 | 7.15 | 6.35 | 18.95 | 7.22 | 6.50 | 18.70 |
+|fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 6.39 | 18.90 | 7.27 | 6.55 | 18.77 |
+
 The training command for reproducing is given below:

 ```
@@ -272,7 +272,11 @@ def decode_one_batch(
         x=feature, x_lens=feature_lens
     )
     hyps = []
+    zh_hyps = []
+    en_hyps = []
     pattern = re.compile(r"([\u4e00-\u9fff])")
+    en_letter = "[\u0041-\u005a|\u0061-\u007a]+"  # English letters
+    zh_char = "[\u4e00-\u9fa5]+"  # Chinese chars
     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
             model=model,
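For readability: the splitting logic that the rest of this patch repeats in each decoding branch can be read as one small helper. The sketch below uses the same regular expressions as the diff, but `split_zh_en` is a hypothetical name that does not exist in the recipe; the actual change inlines this logic in the decode script.

```python
# Sketch only: the per-hypothesis language-splitting logic used throughout this
# patch, factored into a hypothetical helper for illustration.
import re

pattern = re.compile(r"([\u4e00-\u9fff])")    # split around every CJK character
en_letter = "[\u0041-\u005a|\u0061-\u007a]+"  # runs of English letters
zh_char = "[\u4e00-\u9fa5]+"                  # runs of Chinese characters


def split_zh_en(text: str):
    """Return (mixed_tokens, zh_chars, en_words) for a code-switched transcript."""
    mixed, zh_text, en_text = [], [], []
    for piece in pattern.split(text.upper()):
        if piece != "":
            tokens = piece.strip().split(" ")
            mixed.extend(tokens)
            for token in tokens:
                zh_text.extend(re.findall(zh_char, token))
                en_text.extend(re.findall(en_letter, token))
    return mixed, zh_text, en_text


print(split_zh_en("我想学deep learning"))
# (['我', '想', '学', 'DEEP', 'LEARNING'], ['我', '想', '学'], ['DEEP', 'LEARNING'])
```

The mixed list (Chinese characters plus English words) feeds the overall dev/test columns, while the zh/en lists feed the separate `*_zh` (CER) and `*_en` (WER) columns in the tables above.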
@@ -287,10 +291,18 @@ def decode_one_batch(
             hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
             chars = pattern.split(hyp.upper())
             chars_new = []
+            zh_text = []
+            en_text = []
             for char in chars:
                 if char != "":
-                    chars_new.extend(char.strip().split(" "))
+                    tokens = char.strip().split(" ")
+                    chars_new.extend(tokens)
+                    for token in tokens:
+                        zh_text.extend(re.findall(zh_char, token))
+                        en_text.extend(re.findall(en_letter, token))
             hyps.append(chars_new)
+            zh_hyps.append(zh_text)
+            en_hyps.append(en_text)
     elif (
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
@@ -304,10 +316,18 @@ def decode_one_batch(
             hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
             chars = pattern.split(hyp.upper())
             chars_new = []
+            zh_text = []
+            en_text = []
             for char in chars:
                 if char != "":
-                    chars_new.extend(char.strip().split(" "))
+                    tokens = char.strip().split(" ")
+                    chars_new.extend(tokens)
+                    for token in tokens:
+                        zh_text.extend(re.findall(zh_char, token))
+                        en_text.extend(re.findall(en_letter, token))
             hyps.append(chars_new)
+            zh_hyps.append(zh_text)
+            en_hyps.append(en_text)
     elif params.decoding_method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,
@@ -319,10 +339,18 @@ def decode_one_batch(
             hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
             chars = pattern.split(hyp.upper())
             chars_new = []
+            zh_text = []
+            en_text = []
             for char in chars:
                 if char != "":
-                    chars_new.extend(char.strip().split(" "))
+                    tokens = char.strip().split(" ")
+                    chars_new.extend(tokens)
+                    for token in tokens:
+                        zh_text.extend(re.findall(zh_char, token))
+                        en_text.extend(re.findall(en_letter, token))
             hyps.append(chars_new)
+            zh_hyps.append(zh_text)
+            en_hyps.append(en_text)
     else:
         batch_size = encoder_out.size(0)
@@ -352,22 +380,30 @@ def decode_one_batch(
             )
             chars = pattern.split(hyp.upper())
             chars_new = []
+            zh_text = []
+            en_text = []
             for char in chars:
                 if char != "":
-                    chars_new.extend(char.strip().split(" "))
+                    tokens = char.strip().split(" ")
+                    chars_new.extend(tokens)
+                    for token in tokens:
+                        zh_text.extend(re.findall(zh_char, token))
+                        en_text.extend(re.findall(en_letter, token))
             hyps.append(chars_new)
+            zh_hyps.append(zh_text)
+            en_hyps.append(en_text)
     if params.decoding_method == "greedy_search":
-        return {"greedy_search": hyps}
+        return {"greedy_search": (hyps, zh_hyps, en_hyps)}
     elif params.decoding_method == "fast_beam_search":
         return {
             (
                 f"beam_{params.beam}_"
                 f"max_contexts_{params.max_contexts}_"
                 f"max_states_{params.max_states}"
-            ): hyps
+            ): (hyps, zh_hyps, en_hyps)
         }
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}": (hyps, zh_hyps, en_hyps)}


 def decode_dataset(
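With these changes, every branch of `decode_one_batch` returns three parallel hypothesis lists per decoding key instead of one. Illustratively (the values below are made up):

```python
# Illustration only: the shape of the dict returned by decode_one_batch after
# this patch. Each value is (mixed tokens, Chinese-only chars, English-only words),
# with one inner list per utterance in the batch.
hyps_dict = {
    "greedy_search": (
        [["我", "想", "学", "DEEP", "LEARNING"]],  # hyps
        [["我", "想", "学"]],                       # zh_hyps
        [["DEEP", "LEARNING"]],                     # en_hyps
    )
}
```

`decode_dataset` (next hunk) unpacks each tuple and accumulates the reference–hypothesis pairs into `results`, `zh_results`, and `en_results`.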
@@ -410,17 +446,30 @@ def decode_dataset(
         log_interval = 20

     results = defaultdict(list)
+    zh_results = defaultdict(list)
+    en_results = defaultdict(list)
     pattern = re.compile(r"([\u4e00-\u9fff])")
+    en_letter = "[\u0041-\u005a|\u0061-\u007a]+"  # English letters
+    zh_char = "[\u4e00-\u9fa5]+"  # Chinese chars
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
-        # texts = [list(str(text).replace(" ", "")) for text in texts]
+        zh_texts = []
+        en_texts = []
         for i in range(len(texts)):
             text = texts[i]
             chars = pattern.split(text.upper())
             chars_new = []
+            zh_text = []
+            en_text = []
             for char in chars:
                 if char != "":
-                    chars_new.extend(char.strip().split(" "))
+                    tokens = char.strip().split(" ")
+                    chars_new.extend(tokens)
+                    for token in tokens:
+                        zh_text.extend(re.findall(zh_char, token))
+                        en_text.extend(re.findall(en_letter, token))
+            zh_texts.append(zh_text)
+            en_texts.append(en_text)
             texts[i] = chars_new
         hyps_dict = decode_one_batch(
             params=params,
@@ -431,13 +480,25 @@ def decode_dataset(
             sp=sp,
         )

-        for name, hyps in hyps_dict.items():
+        for name, hyps_texts in hyps_dict.items():
             this_batch = []
+            this_batch_zh = []
+            this_batch_en = []
+            # print(hyps_texts)
+            hyps, zh_hyps, en_hyps = hyps_texts
             assert len(hyps) == len(texts)
             for hyp_words, ref_text in zip(hyps, texts):
                 this_batch.append((ref_text, hyp_words))

+            for hyp_words, ref_text in zip(zh_hyps, zh_texts):
+                this_batch_zh.append((ref_text, hyp_words))
+
+            for hyp_words, ref_text in zip(en_hyps, en_texts):
+                this_batch_en.append((ref_text, hyp_words))
+
             results[name].extend(this_batch)
+            zh_results[name + "_zh"].extend(this_batch_zh)
+            en_results[name + "_en"].extend(this_batch_en)

         num_cuts += len(texts)
@@ -447,7 +508,7 @@ def decode_dataset(
             logging.info(
                 f"batch {batch_str}, cuts processed until now is {num_cuts}"
             )
-    return results
+    return results, zh_results, en_results


 def save_results(
@@ -663,7 +724,7 @@ def main():
     test_dl = [dev_dl, test_dl]

     for test_set, test_dl in zip(test_sets, test_dl):
-        results_dict = decode_dataset(
+        results_dict, zh_results_dict, en_results_dict = decode_dataset(
             dl=test_dl,
             params=params,
             model=model,
@@ -676,6 +737,16 @@ def main():
             test_set_name=test_set,
             results_dict=results_dict,
         )
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=zh_results_dict,
+        )
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=en_results_dict,
+        )

     logging.info("Done!")
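Since `save_results` is now called three times per test set, each language's pairs are written out under their own keys (`name + "_zh"`, `name + "_en"`). For intuition, both reported metrics are the same token-level edit-distance rate; only the token unit differs: Chinese characters for CER, English words for WER. A minimal sketch of that computation follows — it is not the icefall implementation (which relies on the library's own error-stats utilities), only an illustration of the metric.

```python
# Sketch only: a generic token-error-rate computation, NOT icefall's save_results.
# The same routine yields CER on the zh pairs and WER on the en pairs, because
# only the token unit differs (characters vs. words).
def edit_distance(ref, hyp):
    """Levenshtein distance between two token sequences."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, 1):
            cur[j] = min(
                prev[j] + 1,             # deletion
                cur[j - 1] + 1,          # insertion
                prev[j - 1] + (r != h),  # substitution
            )
        prev = cur
    return prev[-1]


def error_rate(pairs):
    """pairs: list of (reference tokens, hypothesis tokens) tuples."""
    errors = sum(edit_distance(ref, hyp) for ref, hyp in pairs)
    total = sum(len(ref) for ref, _ in pairs)
    return 100.0 * errors / max(total, 1)


zh_pairs = [(["你", "好", "世", "界"], ["你", "好", "世"])]  # CER: 1 error / 4 chars = 25.0
en_pairs = [(["HELLO", "WORLD"], ["HELLO", "WORD"])]         # WER: 1 error / 2 words = 50.0
print(error_rate(zh_pairs), error_rate(en_pairs))
```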