From 76c311caf47832deb2bbc9d8e4ca65d7f295231c Mon Sep 17 00:00:00 2001
From: jinzr <60612200+JinZr@users.noreply.github.com>
Date: Tue, 17 Oct 2023 16:40:41 +0800
Subject: [PATCH] Update ctc_decode.py

---
 egs/multi_zh-hans/ASR/zipformer/ctc_decode.py | 101 +-----------------
 1 file changed, 4 insertions(+), 97 deletions(-)

diff --git a/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py b/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py
index 0036bb5bd..5482c4fae 100755
--- a/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py
+++ b/egs/multi_zh-hans/ASR/zipformer/ctc_decode.py
@@ -54,14 +54,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.decode import (
-    get_lattice,
-    nbest_decoding,
-    nbest_oracle,
-    one_best_decoding,
-    rescore_with_n_best_list,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -204,12 +197,9 @@ def get_decoding_params() -> AttributeDict:
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    HLG: Optional[k2.Fsa],
     H: Optional[k2.Fsa],
     bpe_model: Optional[spm.SentencePieceProcessor],
     batch: dict,
-    word_table: k2.SymbolTable,
-    G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -234,8 +224,6 @@
 
       model:
         The neural model.
-      HLG:
-        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
       H:
         The ctc topo. Used only when params.decoding_method is ctc-decoding.
       bpe_model:
@@ -254,10 +242,7 @@
       Return the decoding result. See above description for the format of
       the returned dict. Note: If it decodes to nothing, then return None.
     """
-    if HLG is not None:
-        device = HLG.device
-    else:
-        device = H.device
+    device = H.device
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
@@ -296,13 +281,8 @@
         1,
     ).to(torch.int32)
 
-    if H is None:
-        assert HLG is not None
-        decoding_graph = HLG
-    else:
-        assert HLG is None
-        assert bpe_model is not None
-        decoding_graph = H
+    assert bpe_model is not None
+    decoding_graph = H
 
     lattice = get_lattice(
         nnet_output=ctc_output,
@@ -333,79 +313,6 @@
         key = "ctc-decoding"
         return {key: hyps}
 
-    if params.decoding_method == "nbest-oracle":
-        # Note: You can also pass rescored lattices to it.
-        # We choose the HLG decoded lattice for speed reasons
-        # as HLG decoding is faster and the oracle WER
-        # is only slightly worse than that of rescored lattices.
-        best_path = nbest_oracle(
-            lattice=lattice,
-            num_paths=params.num_paths,
-            ref_texts=supervisions["text"],
-            word_table=word_table,
-            nbest_scale=params.nbest_scale,
-            oov="<UNK>",
-        )
-        hyps = get_texts(best_path)
-        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
-        return {key: hyps}
-
-    if params.decoding_method in ["1best", "nbest"]:
-        if params.decoding_method == "1best":
-            best_path = one_best_decoding(
-                lattice=lattice, use_double_scores=params.use_double_scores
-            )
-            key = "no_rescore"
-        else:
-            best_path = nbest_decoding(
-                lattice=lattice,
-                num_paths=params.num_paths,
-                use_double_scores=params.use_double_scores,
-                nbest_scale=params.nbest_scale,
-            )
-            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
-
-        hyps = get_texts(best_path)
-        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        return {key: hyps}
-
-    assert params.decoding_method in [
-        "nbest-rescoring",
-        "whole-lattice-rescoring",
-    ]
-
-    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
-    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
-    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
-
-    if params.decoding_method == "nbest-rescoring":
-        best_path_dict = rescore_with_n_best_list(
-            lattice=lattice,
-            G=G,
-            num_paths=params.num_paths,
-            lm_scale_list=lm_scale_list,
-            nbest_scale=params.nbest_scale,
-        )
-    elif params.decoding_method == "whole-lattice-rescoring":
-        best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice,
-            G_with_epsilon_loops=G,
-            lm_scale_list=lm_scale_list,
-        )
-    else:
-        assert False, f"Unsupported decoding method: {params.decoding_method}"
-
-    ans = dict()
-    if best_path_dict is not None:
-        for lm_scale_str, best_path in best_path_dict.items():
-            hyps = get_texts(best_path)
-            hyps = [[word_table[i] for i in ids] for ids in hyps]
-            ans[lm_scale_str] = hyps
-    else:
-        ans = None
-    return ans
-
 
 def decode_dataset(
     dl: torch.utils.data.DataLoader,
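After this patch, the only code path left in decode_one_batch is plain CTC decoding: build a lattice from the CTC log-probs against the CTC topology H, take the single best path, and map the recovered BPE token IDs back to text. A minimal sketch of that surviving flow, using the icefall helpers imported above; the variables come from the surrounding function, and the beam values below are illustrative placeholders rather than the settings configured in this file:

    # Sketch of the CTC one-best path kept by this patch. `ctc_output`
    # (N, T, C) log-probs, `supervision_segments`, `H`, `bpe_model`, and
    # `params` are the objects already in scope in decode_one_batch().
    from icefall.decode import get_lattice, one_best_decoding
    from icefall.utils import get_texts

    lattice = get_lattice(
        nnet_output=ctc_output,            # CTC log-probs from the model
        decoding_graph=H,                  # CTC topology; HLG is no longer used
        supervision_segments=supervision_segments,
        search_beam=20,                    # illustrative beam settings
        output_beam=8,
        min_active_states=30,
        max_active_states=10000,
        subsampling_factor=params.subsampling_factor,
    )
    best_path = one_best_decoding(
        lattice=lattice, use_double_scores=params.use_double_scores
    )
    token_ids = get_texts(best_path)       # List[List[int]] of BPE token IDs
    hyps = [s.split() for s in bpe_model.decode(token_ids)]
    results = {"ctc-decoding": hyps}       # the shape decode_one_batch returns

Since H and the BPE model fully determine this path, the HLG graph, word table, and LM rescoring arguments removed above had become dead parameters for this recipe.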