mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
removed timestamp_decode in ctc_decode
This commit is contained in:
parent
1c5c0c6a09
commit
1d6530cfb5
@ -85,13 +85,11 @@ from icefall.decode import (
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
DecodingResults,
|
||||
make_pad_mask,
|
||||
setup_logger,
|
||||
store_transcripts_and_timestamps_withoutref,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
parse_hyp_and_timestamp_ch,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
@ -339,29 +337,25 @@ def decode_one_batch(
|
||||
|
||||
ctc_output = model.ctc_output(encoder_out) # (N, T, C)
|
||||
|
||||
hyp_tokens = []
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "ctc-greedy-search" and params.max_sym_per_frame == 1:
|
||||
res = ctc_greedy_search(
|
||||
hyp_tokens = ctc_greedy_search(
|
||||
ctc_output=ctc_output,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
return_timestamps = True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
|
||||
hyps, timestamps = parse_hyp_and_timestamp_ch(
|
||||
res=res,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
word_table = lexicon.token_table,
|
||||
# frame_shift_ms=params.frame_shift_ms,
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
|
||||
key = f"blank_penalty_{params.blank_penalty}"
|
||||
if params.decoding_method == "ctc-greedy-search":
|
||||
return {"ctc-greedy-search_" + key: (hyps, timestamps)}
|
||||
return {"ctc-greedy-search_" + key: hyps}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key += f"_beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
@ -373,9 +367,9 @@ def decode_one_batch(
|
||||
key += f"_ilme_scale_{params.ilme_scale}"
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: (hyps, timestamps)}
|
||||
return {key: hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}_" + key: (hyps, timestamps)}
|
||||
return {f"beam_size_{params.beam_size}_" + key: hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
@ -385,8 +379,7 @@ def decode_dataset(
|
||||
lexicon: Lexicon,
|
||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
with_timestamp: bool = False,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str], List[Tuple[float, float]]]]]:
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
@ -400,13 +393,11 @@ def decode_dataset(
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
with_timestamp:
|
||||
Whether to decode with timestamp.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains 4 elements:
|
||||
Respectively, they are cut_id, the reference transcript, the predicted result and the decoded_timestamps.
|
||||
Its value is a list of tuples. Each tuple contains 3 elements:
|
||||
Respectively, they are cut_id, the reference transcript, and the predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
@ -434,26 +425,12 @@ def decode_dataset(
|
||||
decoding_graph=decoding_graph,
|
||||
batch=batch,
|
||||
)
|
||||
if with_timestamp:
|
||||
|
||||
for name, (hyps, timestamps_hyp) in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts) and len(timestamps_hyp) == len(hyps)
|
||||
for cut_id, hyp_words, ref_text, time_hyp in zip(
|
||||
cut_ids, hyps, texts, timestamps_hyp
|
||||
):
|
||||
this_batch.append((cut_id, ref_text, hyp_words, time_hyp))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
else:
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
this_batch.append((cut_id, ref_text, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
this_batch.append((cut_id, ref_text, hyp_words))
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
@ -467,7 +444,7 @@ def decode_dataset(
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str], List[Tuple[float, float]]]]],
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
@ -475,7 +452,7 @@ def save_results(
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts_and_timestamps_withoutref(filename=recog_path, texts=results)
|
||||
store_transcripts(filename=recog_path, texts=results, char_level = True)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
@ -483,12 +460,11 @@ def save_results(
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
result_without_timestamp = [(res[0], res[1], res[2]) for res in results]
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f,
|
||||
f"{test_set_name}-{key}",
|
||||
result_without_timestamp,
|
||||
results,
|
||||
enable_log=True,
|
||||
compute_CER=True,
|
||||
)
|
||||
@ -720,7 +696,6 @@ def main():
|
||||
lexicon=lexicon,
|
||||
graph_compiler=graph_compiler,
|
||||
decoding_graph=decoding_graph,
|
||||
with_timestamp=True,
|
||||
)
|
||||
|
||||
save_results(
|
||||
|
Loading…
x
Reference in New Issue
Block a user