From 9b13eac94686fe4a31eb190b8deaf3ff967feacd Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 14 Aug 2024 15:47:28 +0900
Subject: [PATCH] fix misc line

---
 egs/librispeech/ASR/zipformer/decode.py | 235 ++++++++++++------------
 1 file changed, 117 insertions(+), 118 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py
index cfe5638b7..52a489eb3 100755
--- a/egs/librispeech/ASR/zipformer/decode.py
+++ b/egs/librispeech/ASR/zipformer/decode.py
@@ -371,7 +371,6 @@ def get_parser():
         modified_beam_search_LODR.
         """,
     )
-<<<<<<< HEAD
 
     parser.add_argument(
         "--skip-scoring",
@@ -631,9 +630,9 @@ def decode_one_batch(
     elif "modified_beam_search" in params.decoding_method:
         prefix += f"_beam-size-{params.beam_size}"
         if params.decoding_method in (
-                "modified_beam_search_lm_rescore",
-                "modified_beam_search_lm_rescore_LODR",
-            ):
+            "modified_beam_search_lm_rescore",
+            "modified_beam_search_lm_rescore_LODR",
+        ):
             ans = dict()
             assert ans_dict is not None
             for key, hyps in ans_dict.items():
@@ -650,17 +649,17 @@
 
 
 def decode_dataset(
-        dl: torch.utils.data.DataLoader,
-        params: AttributeDict,
-        model: nn.Module,
-        sp: spm.SentencePieceProcessor,
-        word_table: Optional[k2.SymbolTable] = None,
-        decoding_graph: Optional[k2.Fsa] = None,
-        context_graph: Optional[ContextGraph] = None,
-        LM: Optional[LmScorer] = None,
-        ngram_lm=None,
-        ngram_lm_scale: float = 0.0,
-    ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    word_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
+    LM: Optional[LmScorer] = None,
+    ngram_lm=None,
+    ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
 
     Args:
@@ -703,17 +702,17 @@ def decode_dataset(
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
         hyps_dict = decode_one_batch(
-                params=params,
-                model=model,
-                sp=sp,
-                decoding_graph=decoding_graph,
-                context_graph=context_graph,
-                word_table=word_table,
-                batch=batch,
-                LM=LM,
-                ngram_lm=ngram_lm,
-                ngram_lm_scale=ngram_lm_scale,
-            )
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+            context_graph=context_graph,
+            word_table=word_table,
+            batch=batch,
+            LM=LM,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
+        )
 
         for name, hyps in hyps_dict.items():
             this_batch = []
@@ -734,10 +733,10 @@
 
 
 def save_asr_output(
-        params: AttributeDict,
-        test_set_name: str,
-        results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
-    ):
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
     """
     Save text produced by ASR.
     """
@@ -752,10 +751,10 @@
 
 
 def save_wer_results(
-        params: AttributeDict,
-        test_set_name: str,
-        results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]],
-    ):
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]],
+):
     """
     Save WER and per-utterance word alignments.
     """
@@ -766,8 +765,8 @@ def save_wer_results(
         errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w", encoding="utf8") as fd:
             wer = write_error_stats(
-                    fd, f"{test_set_name}-{key}", results, enable_log=True
-                )
+                fd, f"{test_set_name}-{key}", results, enable_log=True
+            )
             test_set_wers[key] = wer
 
         logging.info(f"Wrote detailed error stats to {errs_filename}")
@@ -804,18 +803,18 @@ def main():
     set_caching_enabled(True)  # lhotse
 
     assert params.decoding_method in (
-            "greedy_search",
-            "beam_search",
-            "fast_beam_search",
-            "fast_beam_search_nbest",
-            "fast_beam_search_nbest_LG",
-            "fast_beam_search_nbest_oracle",
-            "modified_beam_search",
-            "modified_beam_search_LODR",
-            "modified_beam_search_lm_shallow_fusion",
-            "modified_beam_search_lm_rescore",
-            "modified_beam_search_lm_rescore_LODR",
-        )
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_nbest",
+        "fast_beam_search_nbest_LG",
+        "fast_beam_search_nbest_oracle",
+        "modified_beam_search",
+        "modified_beam_search_LODR",
+        "modified_beam_search_lm_shallow_fusion",
+        "modified_beam_search_lm_rescore",
+        "modified_beam_search_lm_rescore_LODR",
+    )
     params.res_dir = params.exp_dir / params.decoding_method
 
     if os.path.exists(params.context_file):
@@ -830,11 +829,11 @@ def main():
 
     if params.causal:
         assert (
-                "," not in params.chunk_size
-            ), "chunk_size should be one value in decoding."
+            "," not in params.chunk_size
+        ), "chunk_size should be one value in decoding."
         assert (
-                "," not in params.left_context_frames
-            ), "left_context_frames should be one value in decoding."
+            "," not in params.left_context_frames
+        ), "left_context_frames should be one value in decoding."
         params.suffix += f"_chunk-{params.chunk_size}"
         params.suffix += f"_left-context-{params.left_context_frames}"
 
@@ -850,9 +849,9 @@ def main():
     elif "beam_search" in params.decoding_method:
         params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}"
         if params.decoding_method in (
-                "modified_beam_search",
-                "modified_beam_search_LODR",
-            ):
+            "modified_beam_search",
+            "modified_beam_search_LODR",
+        ):
             if params.has_contexts:
                 params.suffix += f"-context-score-{params.context_score}"
     else:
@@ -864,8 +863,8 @@ def main():
 
     if "LODR" in params.decoding_method:
         params.suffix += (
-                f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
-            )
+            f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+        )
 
     if params.use_averaged_model:
         params.suffix += "_use-averaged-model"
@@ -895,18 +894,18 @@ def main():
     if not params.use_averaged_model:
         if params.iter > 0:
             filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                    : params.avg
-                ]
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                        f"No checkpoints found for"
-                        f" --iter {params.iter}, --avg {params.avg}"
-                    )
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
             elif len(filenames) < params.avg:
                 raise ValueError(
-                        f"Not enough checkpoints ({len(filenames)}) found for"
-                        f" --iter {params.iter}, --avg {params.avg}"
-                    )
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
             logging.info(f"averaging {filenames}")
             model.to(device)
             model.load_state_dict(average_checkpoints(filenames, device=device))
@@ -924,32 +923,32 @@ def main():
     else:
         if params.iter > 0:
             filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                    : params.avg + 1
-                ]
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
-                        f"No checkpoints found for"
-                        f" --iter {params.iter}, --avg {params.avg}"
-                    )
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
             elif len(filenames) < params.avg + 1:
                 raise ValueError(
-                        f"Not enough checkpoints ({len(filenames)}) found for"
-                        f" --iter {params.iter}, --avg {params.avg}"
-                    )
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
             filename_start = filenames[-1]
             filename_end = filenames[0]
             logging.info(
-                    "Calculating the averaged model over iteration checkpoints"
-                    f" from {filename_start} (excluded) to {filename_end}"
-                )
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
             model.to(device)
             model.load_state_dict(
-                    average_checkpoints_with_averaged_model(
-                        filename_start=filename_start,
-                        filename_end=filename_end,
-                        device=device,
-                    )
-                )
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
         else:
             assert params.avg > 0, params.avg
             start = params.epoch - params.avg
@@ -957,34 +956,34 @@ def main():
             filename_start = f"{params.exp_dir}/epoch-{start}.pt"
             filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
             logging.info(
-                    f"Calculating the averaged model over epoch range from "
-                    f"{start} (excluded) to {params.epoch}"
-                )
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
             model.to(device)
             model.load_state_dict(
-                    average_checkpoints_with_averaged_model(
-                        filename_start=filename_start,
-                        filename_end=filename_end,
-                        device=device,
-                    )
-                )
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
 
     model.to(device)
     model.eval()
 
     # only load the neural network LM if required
     if params.use_shallow_fusion or params.decoding_method in (
-            "modified_beam_search_lm_rescore",
-            "modified_beam_search_lm_rescore_LODR",
-            "modified_beam_search_lm_shallow_fusion",
-            "modified_beam_search_LODR",
-        ):
+        "modified_beam_search_lm_rescore",
+        "modified_beam_search_lm_rescore_LODR",
+        "modified_beam_search_lm_shallow_fusion",
+        "modified_beam_search_LODR",
+    ):
         LM = LmScorer(
-                lm_type=params.lm_type,
-                params=params,
-                device=device,
-                lm_scale=params.lm_scale,
-            )
+            lm_type=params.lm_type,
+            params=params,
+            device=device,
+            lm_scale=params.lm_scale,
+        )
         LM.to(device)
         LM.eval()
     else:
@@ -1010,10 +1009,10 @@ def main():
         lm_filename = f"{params.tokens_ngram}gram.fst.txt"
         logging.info(f"Loading token level lm: {lm_filename}")
         ngram_lm = NgramLm(
-                str(params.lang_dir / lm_filename),
-                backoff_id=params.backoff_id,
-                is_binary=False,
-            )
+            str(params.lang_dir / lm_filename),
+            backoff_id=params.backoff_id,
+            is_binary=False,
+        )
         logging.info(f"num states: {ngram_lm.lm.num_states}")
         ngram_lm_scale = params.ngram_lm_scale
     else:
@@ -1027,8 +1026,8 @@ def main():
            lg_filename = params.lang_dir / "LG.pt"
            logging.info(f"Loading {lg_filename}")
            decoding_graph = k2.Fsa.from_dict(
-                    torch.load(lg_filename, map_location=device)
-                )
+                torch.load(lg_filename, map_location=device)
+            )
            decoding_graph.scores *= params.ngram_lm_scale
        else:
            word_table = None
@@ -1067,17 +1066,17 @@ def main():
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
-                dl=test_dl,
-                params=params,
-                model=model,
-                sp=sp,
-                word_table=word_table,
-                decoding_graph=decoding_graph,
-                context_graph=context_graph,
-                LM=LM,
-                ngram_lm=ngram_lm,
-                ngram_lm_scale=ngram_lm_scale,
-            )
+            dl=test_dl,
+            params=params,
+            model=model,
+            sp=sp,
+            word_table=word_table,
+            decoding_graph=decoding_graph,
+            context_graph=context_graph,
+            LM=LM,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
+        )
 
         save_asr_output(
             params=params,