remove space if existing

fix
This commit is contained in:
yfyeung 2025-04-29 10:32:12 -07:00 committed by Your Name
parent d1c336f589
commit 26c022665b
2 changed files with 15 additions and 18 deletions

View File

@ -382,9 +382,7 @@ def decode_dataset(
for lm_scale, hyps in hyps_dict.items(): for lm_scale, hyps in hyps_dict.items():
this_batch = [] this_batch = []
assert len(hyps) == len(texts) assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts):
print(f"ref: {''.join(ref_text)}")
print(f"hyp: {''.join(hyp_words)}")
this_batch.append((cut_id, ref_text, hyp_words)) this_batch.append((cut_id, ref_text, hyp_words))
results[lm_scale].extend(this_batch) results[lm_scale].extend(this_batch)
@ -401,40 +399,38 @@ def decode_dataset(
def save_results( def save_results(
params: AttributeDict, params: AttributeDict,
test_set_name: str, test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
): ):
enable_log = True
test_set_wers = dict() test_set_wers = dict()
for key, results in results_dict.items(): for key, results in results_dict.items():
recog_path = ( recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
) )
results = sorted(results) results = sorted(results)
store_transcripts(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results, char_level=True)
if enable_log: logging.info(f"The transcripts are stored in {recog_path}")
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out CERs, per-word error statistics and aligned
# ref/hyp pairs. # ref/hyp pairs.
errs_filename = ( errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
) )
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log f,
f"{test_set_name}-{key}",
results,
enable_log=True,
compute_CER=True,
) )
test_set_wers[key] = wer test_set_wers[key] = wer
if enable_log: logging.info("Wrote detailed error stats to {}".format(errs_filename))
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tCER", file=f) print("settings\tCER", file=f)
for key, val in test_set_wers: for key, val in test_set_wers:

View File

@ -339,6 +339,7 @@ def compute_loss(
messages = [] messages = []
for i, text in enumerate(texts): for i, text in enumerate(texts):
text = text.replace(" ", "")
message = [ message = [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
{"role": "assistant", "content": text}, {"role": "assistant", "content": text},