Remove timestamp decoding from ctc_decode

This commit is contained in:
hhzzff 2025-07-04 11:13:56 +08:00
parent 1c5c0c6a09
commit 1d6530cfb5

View File

@ -85,13 +85,11 @@ from icefall.decode import (
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
DecodingResults,
make_pad_mask, make_pad_mask,
setup_logger, setup_logger,
store_transcripts_and_timestamps_withoutref, store_transcripts,
str2bool, str2bool,
write_error_stats, write_error_stats,
parse_hyp_and_timestamp_ch,
) )
LOG_EPS = math.log(1e-10) LOG_EPS = math.log(1e-10)
@ -339,29 +337,25 @@ def decode_one_batch(
ctc_output = model.ctc_output(encoder_out) # (N, T, C) ctc_output = model.ctc_output(encoder_out) # (N, T, C)
hyp_tokens = []
hyps = [] hyps = []
if params.decoding_method == "ctc-greedy-search" and params.max_sym_per_frame == 1: if params.decoding_method == "ctc-greedy-search" and params.max_sym_per_frame == 1:
res = ctc_greedy_search( hyp_tokens = ctc_greedy_search(
ctc_output=ctc_output, ctc_output=ctc_output,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
return_timestamps = True,
) )
else: else:
raise ValueError( raise ValueError(
f"Unsupported decoding method: {params.decoding_method}" f"Unsupported decoding method: {params.decoding_method}"
) )
hyps, timestamps = parse_hyp_and_timestamp_ch( for i in range(encoder_out.size(0)):
res=res, hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
subsampling_factor=params.subsampling_factor,
word_table = lexicon.token_table,
# frame_shift_ms=params.frame_shift_ms,
)
key = f"blank_penalty_{params.blank_penalty}" key = f"blank_penalty_{params.blank_penalty}"
if params.decoding_method == "ctc-greedy-search": if params.decoding_method == "ctc-greedy-search":
return {"ctc-greedy-search_" + key: (hyps, timestamps)} return {"ctc-greedy-search_" + key: hyps}
elif "fast_beam_search" in params.decoding_method: elif "fast_beam_search" in params.decoding_method:
key += f"_beam_{params.beam}_" key += f"_beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_" key += f"max_contexts_{params.max_contexts}_"
@ -373,9 +367,9 @@ def decode_one_batch(
key += f"_ilme_scale_{params.ilme_scale}" key += f"_ilme_scale_{params.ilme_scale}"
key += f"_ngram_lm_scale_{params.ngram_lm_scale}" key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: (hyps, timestamps)} return {key: hyps}
else: else:
return {f"beam_size_{params.beam_size}_" + key: (hyps, timestamps)} return {f"beam_size_{params.beam_size}_" + key: hyps}
def decode_dataset( def decode_dataset(
@ -385,8 +379,7 @@ def decode_dataset(
lexicon: Lexicon, lexicon: Lexicon,
graph_compiler: CharCtcTrainingGraphCompiler, graph_compiler: CharCtcTrainingGraphCompiler,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
with_timestamp: bool = False, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str], List[Tuple[float, float]]]]]:
"""Decode dataset. """Decode dataset.
Args: Args:
@ -400,13 +393,11 @@ def decode_dataset(
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest, only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
with_timestamp:
Whether to decode with timestamp.
Returns: Returns:
Return a dict, whose key may be "greedy_search" if greedy search Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used. is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains 4 elements: Its value is a list of tuples. Each tuple contains 3 elements:
Respectively, they are cut_id, the reference transcript, the predicted result and the decoded_timestamps. Respectively, they are cut_id, the reference transcript, and the predicted result.
""" """
num_cuts = 0 num_cuts = 0
@ -434,25 +425,11 @@ def decode_dataset(
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
batch=batch, batch=batch,
) )
if with_timestamp:
for name, (hyps, timestamps_hyp) in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts) and len(timestamps_hyp) == len(hyps)
for cut_id, hyp_words, ref_text, time_hyp in zip(
cut_ids, hyps, texts, timestamps_hyp
):
this_batch.append((cut_id, ref_text, hyp_words, time_hyp))
results[name].extend(this_batch)
else:
for name, hyps in hyps_dict.items(): for name, hyps in hyps_dict.items():
this_batch = [] this_batch = []
assert len(hyps) == len(texts) assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
this_batch.append((cut_id, ref_text, hyp_words)) this_batch.append((cut_id, ref_text, hyp_words))
results[name].extend(this_batch) results[name].extend(this_batch)
num_cuts += len(texts) num_cuts += len(texts)
@ -467,7 +444,7 @@ def decode_dataset(
def save_results( def save_results(
params: AttributeDict, params: AttributeDict,
test_set_name: str, test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str], List[Tuple[float, float]]]]], results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
): ):
test_set_wers = dict() test_set_wers = dict()
for key, results in results_dict.items(): for key, results in results_dict.items():
@ -475,7 +452,7 @@ def save_results(
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
) )
results = sorted(results) results = sorted(results)
store_transcripts_and_timestamps_withoutref(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results, char_level = True)
logging.info(f"The transcripts are stored in {recog_path}") logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
@ -483,12 +460,11 @@ def save_results(
errs_filename = ( errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
) )
result_without_timestamp = [(res[0], res[1], res[2]) for res in results]
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f,
f"{test_set_name}-{key}", f"{test_set_name}-{key}",
result_without_timestamp, results,
enable_log=True, enable_log=True,
compute_CER=True, compute_CER=True,
) )
@ -720,7 +696,6 @@ def main():
lexicon=lexicon, lexicon=lexicon,
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
with_timestamp=True,
) )
save_results( save_results(