diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index 78e1f4096..7f4d000fc 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -98,10 +98,12 @@ from icefall.checkpoint import ( ) from icefall.utils import ( AttributeDict, + DecodingResults, + parse_hyp_and_timestamp, setup_logger, - store_transcripts, + store_transcripts_and_timestamps, str2bool, - write_error_stats, + write_error_stats_with_timestamps, ) LOG_EPS = math.log(1e-10) @@ -237,7 +239,7 @@ def decode_one_batch( sp: spm.SentencePieceProcessor, batch: dict, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: +) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -287,7 +289,7 @@ def decode_one_batch( hyps = [] if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( + res = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -295,63 +297,74 @@ def decode_one_batch( beam=params.beam, max_contexts=params.max_contexts, max_states=params.max_states, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( + res = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( + res = modified_beam_search( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, + return_timestamps=True, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) - + tokens = [] + timestamps = [] for i in range(batch_size): # fmt: off encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] # fmt: on if params.decoding_method == "greedy_search": - hyp = greedy_search( + res = greedy_search( model=model, encoder_out=encoder_out_i, max_sym_per_frame=params.max_sym_per_frame, + return_timestamps=True, ) elif params.decoding_method == "beam_search": - hyp = beam_search( + res = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size, + return_timestamps=True, ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) - hyps.append(sp.decode(hyp).split()) + tokens.extend(res.tokens) + timestamps.extend(res.timestamps) + res = DecodingResults(hyps=tokens, timestamps=timestamps) + + hyps, timestamps = parse_hyp_and_timestamp( + decoding_method=params.decoding_method, + res=res, + sp=sp, + subsampling_factor=params.subsampling_factor, + frame_shift_ms=params.frame_shift_ms, + word_table=word_table, + ) if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} + return {"greedy_search": (hyps, timestamps)} elif params.decoding_method == "fast_beam_search": return { ( f"beam_{params.beam}_" f"max_contexts_{params.max_contexts}_" f"max_states_{params.max_states}" - ): hyps + ): (hyps, timestamps) } else: - return {f"beam_size_{params.beam_size}": hyps} + return {f"beam_size_{params.beam_size}": (hyps, timestamps)} def decode_dataset( @@ -360,7 +373,7 @@ def decode_dataset( model: nn.Module, sp: spm.SentencePieceProcessor, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: +) ->Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: """Decode dataset. Args: @@ -378,9 +391,12 @@ def decode_dataset( Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. + Its value is a list of tuples. Each tuple contains five elements: + - cut_id + - reference transcript + - predicted result + - timestamp of reference transcript + - timestamp of predicted result """ num_cuts = 0 @@ -390,14 +406,26 @@ def decode_dataset( num_batches = "?" if params.decoding_method == "greedy_search": - log_interval = 100 + log_interval = 50 else: - log_interval = 2 + log_interval = 20 results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + timestamps_ref = [] + for cut in batch["supervisions"]["cut"]: + for s in cut.supervisions: + time = [] + if s.alignment is not None and "word" in s.alignment: + time = [ + aliword.start + for aliword in s.alignment["word"] + if aliword.symbol != "" + ] + timestamps_ref.append(time) hyps_dict = decode_one_batch( params=params, @@ -407,12 +435,16 @@ def decode_dataset( batch=batch, ) - for name, hyps in hyps_dict.items(): + for name, (hyps, timestamps_hyp) in hyps_dict.items(): this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + assert len(hyps) == len(texts) and len(timestamps_hyp) == len( + timestamps_ref + ) + for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip( + cut_ids, hyps, texts, timestamps_hyp, timestamps_ref + ): ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) + this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) results[name].extend(this_batch) @@ -428,15 +460,19 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + results_dict: Dict[ + str, + List[Tuple[List[str], List[str], List[str], List[float], List[float]]], + ], ): test_set_wers = dict() + test_set_delays = dict() for key, results in results_dict.items(): recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) - store_transcripts(filename=recog_path, texts=results) + store_transcripts_and_timestamps(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned @@ -445,10 +481,11 @@ def save_results( params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: - wer = write_error_stats( + wer, mean_delay, var_delay = write_error_stats_with_timestamps( f, f"{test_set_name}-{key}", results, enable_log=True ) test_set_wers[key] = wer + test_set_delays[key] = (mean_delay, var_delay) logging.info("Wrote detailed error stats to {}".format(errs_filename)) @@ -461,6 +498,19 @@ def save_results( for key, val in test_set_wers: print("{}\t{}".format(key, val), file=f) + test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) + delays_info = ( + params.res_dir + / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(delays_info, "w") as f: + print("settings\tsymbol-delay", file=f) + for key, val in test_set_delays: + print( + "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]), + file=f, + ) + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_wers: @@ -468,6 +518,13 @@ def save_results( note = "" logging.info(s) + s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_delays: + s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) + note = "" + logging.info(s) + @torch.no_grad() def main(): @@ -517,7 +574,7 @@ def main(): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # and is defined in local/train_bpe_model.py + # and are defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() @@ -586,9 +643,9 @@ def main(): ) ) else: - assert params.avg > 0 + assert params.avg > 0, params.avg start = params.epoch - params.avg - assert start >= 1 + assert start >= 1, start filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info(