diff --git a/egs/aishell/ASR/zipformer/ctc_decode.py b/egs/aishell/ASR/zipformer/ctc_decode.py index a3dcb6930..84f7084e4 100755 --- a/egs/aishell/ASR/zipformer/ctc_decode.py +++ b/egs/aishell/ASR/zipformer/ctc_decode.py @@ -85,13 +85,11 @@ from icefall.decode import ( from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, - DecodingResults, make_pad_mask, setup_logger, - store_transcripts_and_timestamps_withoutref, + store_transcripts, str2bool, write_error_stats, - parse_hyp_and_timestamp_ch, ) LOG_EPS = math.log(1e-10) @@ -339,29 +337,25 @@ def decode_one_batch( ctc_output = model.ctc_output(encoder_out) # (N, T, C) + hyp_tokens = [] hyps = [] if params.decoding_method == "ctc-greedy-search" and params.max_sym_per_frame == 1: - res = ctc_greedy_search( + hyp_tokens = ctc_greedy_search( ctc_output=ctc_output, encoder_out_lens=encoder_out_lens, - return_timestamps = True, ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) - hyps, timestamps = parse_hyp_and_timestamp_ch( - res=res, - subsampling_factor=params.subsampling_factor, - word_table = lexicon.token_table, - # frame_shift_ms=params.frame_shift_ms, - ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) key = f"blank_penalty_{params.blank_penalty}" if params.decoding_method == "ctc-greedy-search": - return {"ctc-greedy-search_" + key: (hyps, timestamps)} + return {"ctc-greedy-search_" + key: hyps} elif "fast_beam_search" in params.decoding_method: key += f"_beam_{params.beam}_" key += f"max_contexts_{params.max_contexts}_" @@ -373,9 +367,9 @@ def decode_one_batch( key += f"_ilme_scale_{params.ilme_scale}" key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - return {key: (hyps, timestamps)} + return {key: hyps} else: - return {f"beam_size_{params.beam_size}_" + key: (hyps, timestamps)} + return {f"beam_size_{params.beam_size}_" + key: hyps} def decode_dataset( @@ -385,8 +379,7 @@ def decode_dataset( lexicon: Lexicon, graph_compiler: CharCtcTrainingGraphCompiler, decoding_graph: Optional[k2.Fsa] = None, - with_timestamp: bool = False, -) -> Dict[str, List[Tuple[str, List[str], List[str], List[Tuple[float, float]]]]]: +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -400,13 +393,11 @@ def decode_dataset( The decoding graph. Can be either a `k2.trivial_graph` or LG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - with_timestamp: - Whether to decode with timestamp. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains 4 elements: - Respectively, they are cut_id, the reference transcript, the predicted result and the decoded_timestamps. + Its value is a list of tuples. Each tuple contains 3 elements: + Respectively, they are cut_id, the reference transcript, and the predicted result. """ num_cuts = 0 @@ -434,26 +425,12 @@ def decode_dataset( decoding_graph=decoding_graph, batch=batch, ) - if with_timestamp: - - for name, (hyps, timestamps_hyp) in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) and len(timestamps_hyp) == len(hyps) - for cut_id, hyp_words, ref_text, time_hyp in zip( - cut_ids, hyps, texts, timestamps_hyp - ): - this_batch.append((cut_id, ref_text, hyp_words, time_hyp)) - - results[name].extend(this_batch) - else: - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - this_batch.append((cut_id, ref_text, hyp_words)) - - results[name].extend(this_batch) + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) + results[name].extend(this_batch) num_cuts += len(texts) @@ -467,7 +444,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str], List[Tuple[float, float]]]]], + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): @@ -475,7 +452,7 @@ def save_results( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) - store_transcripts_and_timestamps_withoutref(filename=recog_path, texts=results) + store_transcripts(filename=recog_path, texts=results, char_level = True) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned @@ -483,12 +460,11 @@ def save_results( errs_filename = ( params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) - result_without_timestamp = [(res[0], res[1], res[2]) for res in results] with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", - result_without_timestamp, + results, enable_log=True, compute_CER=True, ) @@ -720,7 +696,6 @@ def main(): lexicon=lexicon, graph_compiler=graph_compiler, decoding_graph=decoding_graph, - with_timestamp=True, ) save_results(