removed timestamp_decode in ctc_decode

commit 1d6530cfb5
parent 1c5c0c6a09
Author: hhzzff
Date:   2025-07-04 11:13:56 +08:00


@@ -85,13 +85,11 @@ from icefall.decode import (
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    DecodingResults,
     make_pad_mask,
     setup_logger,
-    store_transcripts_and_timestamps_withoutref,
+    store_transcripts,
     str2bool,
     write_error_stats,
-    parse_hyp_and_timestamp_ch,
 )

 LOG_EPS = math.log(1e-10)
@@ -339,29 +337,25 @@ def decode_one_batch(
     ctc_output = model.ctc_output(encoder_out)  # (N, T, C)

     hyp_tokens = []
     hyps = []
     if params.decoding_method == "ctc-greedy-search" and params.max_sym_per_frame == 1:
-        res = ctc_greedy_search(
+        hyp_tokens = ctc_greedy_search(
             ctc_output=ctc_output,
             encoder_out_lens=encoder_out_lens,
-            return_timestamps = True,
         )
     else:
         raise ValueError(
             f"Unsupported decoding method: {params.decoding_method}"
         )
-    hyps, timestamps = parse_hyp_and_timestamp_ch(
-        res=res,
-        subsampling_factor=params.subsampling_factor,
-        word_table = lexicon.token_table,
-        # frame_shift_ms=params.frame_shift_ms,
-    )
+    for i in range(encoder_out.size(0)):
+        hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])

     key = f"blank_penalty_{params.blank_penalty}"
     if params.decoding_method == "ctc-greedy-search":
-        return {"ctc-greedy-search_" + key: (hyps, timestamps)}
+        return {"ctc-greedy-search_" + key: hyps}
     elif "fast_beam_search" in params.decoding_method:
         key += f"_beam_{params.beam}_"
         key += f"max_contexts_{params.max_contexts}_"
@@ -373,9 +367,9 @@ def decode_one_batch(
         key += f"_ilme_scale_{params.ilme_scale}"
         key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
-        return {key: (hyps, timestamps)}
+        return {key: hyps}
     else:
-        return {f"beam_size_{params.beam_size}_" + key: (hyps, timestamps)}
+        return {f"beam_size_{params.beam_size}_" + key: hyps}


 def decode_dataset(
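For orientation: after this change every branch of decode_one_batch returns a Dict[str, List[List[str]]] — hypothesis token lists keyed by the decoding configuration — instead of (hyps, timestamps) pairs. A toy illustration of the shape, with a hypothetical key and values:

hyps_dict = {
    "ctc-greedy-search_blank_penalty_0.0": [
        ["你", "好"],  # hypothesis for the first utterance in the batch
        ["世", "界"],  # hypothesis for the second
    ],
}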
@@ -385,8 +379,7 @@ def decode_dataset(
     lexicon: Lexicon,
     graph_compiler: CharCtcTrainingGraphCompiler,
     decoding_graph: Optional[k2.Fsa] = None,
-    with_timestamp: bool = False,
-) -> Dict[str, List[Tuple[str, List[str], List[str], List[Tuple[float, float]]]]]:
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.

     Args:
@@ -400,13 +393,11 @@ def decode_dataset(
         The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
         only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
         fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
-      with_timestamp:
-        Whether to decode with timestamp.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
-      Its value is a list of tuples. Each tuple contains 4 elements:
-      Respectively, they are cut_id, the reference transcript, the predicted result and the decoded_timestamps.
+      Its value is a list of tuples. Each tuple contains 3 elements:
+      Respectively, they are cut_id, the reference transcript, and the predicted result.
     """
     num_cuts = 0
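A toy example of the new return value (hypothetical cut ids and text; exact reference formatting depends on the recipe's supervisions):

{
    "greedy_search": [
        ("cut-001", "你好", ["你", "好"]),
        ("cut-002", "世界", ["世", "界"]),
    ],
}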
@@ -434,26 +425,12 @@ def decode_dataset(
             decoding_graph=decoding_graph,
             batch=batch,
         )
-        if with_timestamp:
-            for name, (hyps, timestamps_hyp) in hyps_dict.items():
-                this_batch = []
-                assert len(hyps) == len(texts) and len(timestamps_hyp) == len(hyps)
-                for cut_id, hyp_words, ref_text, time_hyp in zip(
-                    cut_ids, hyps, texts, timestamps_hyp
-                ):
-                    this_batch.append((cut_id, ref_text, hyp_words, time_hyp))
-                results[name].extend(this_batch)
-        else:
-            for name, hyps in hyps_dict.items():
-                this_batch = []
-                assert len(hyps) == len(texts)
-                for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-                    this_batch.append((cut_id, ref_text, hyp_words))
-                results[name].extend(this_batch)
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                this_batch.append((cut_id, ref_text, hyp_words))
+            results[name].extend(this_batch)

         num_cuts += len(texts)
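A runnable sketch of the simplified aggregation above, on toy data (assuming results is a defaultdict(list), as is typical in these icefall decode scripts):

from collections import defaultdict

results = defaultdict(list)
hyps_dict = {"ctc-greedy-search_blank_penalty_0.0": [["你", "好"], ["世", "界"]]}
texts = ["你好", "世界"]  # reference transcripts for this batch
cut_ids = ["cut-001", "cut-002"]

for name, hyps in hyps_dict.items():
    this_batch = []
    assert len(hyps) == len(texts)
    for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
        this_batch.append((cut_id, ref_text, hyp_words))
    results[name].extend(this_batch)

print(dict(results))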
@@ -467,7 +444,7 @@ def decode_dataset(
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[str, List[str], List[str], List[Tuple[float, float]]]]],
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
@@ -475,7 +452,7 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         results = sorted(results)
-        store_transcripts_and_timestamps_withoutref(filename=recog_path, texts=results)
+        store_transcripts(filename=recog_path, texts=results, char_level=True)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
@@ -483,12 +460,11 @@ def save_results(
         errs_filename = (
             params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        result_without_timestamp = [(res[0], res[1], res[2]) for res in results]
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f,
                 f"{test_set_name}-{key}",
-                result_without_timestamp,
+                results,
                 enable_log=True,
                 compute_CER=True,
             )
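With 3-tuples throughout, the intermediate result_without_timestamp list is no longer needed and results feeds write_error_stats directly. A minimal stand-in for what char-level transcript storage might then write (the output format here is an assumption for illustration, not icefall's actual store_transcripts):

def store_transcripts_sketch(filename, texts, char_level=True):
    # texts: iterable of (cut_id, ref_text, hyp_words) tuples
    with open(filename, "w", encoding="utf-8") as f:
        for cut_id, ref, hyp in texts:
            if char_level:
                # split references/hypotheses into characters for CER scoring
                ref = list("".join(ref).replace(" ", ""))
                hyp = list("".join(hyp).replace(" ", ""))
            print(f"{cut_id}:\tref={ref}", file=f)
            print(f"{cut_id}:\thyp={hyp}", file=f)

store_transcripts_sketch("recogs-demo.txt", [("cut-001", "你好", ["你", "好"])])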
@@ -720,7 +696,6 @@ def main():
             lexicon=lexicon,
             graph_compiler=graph_compiler,
             decoding_graph=decoding_graph,
-            with_timestamp=True,
         )

         save_results(