diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst
index 40100bc5a..a8b0683f4 100644
--- a/docs/source/recipes/librispeech/conformer_ctc.rst
+++ b/docs/source/recipes/librispeech/conformer_ctc.rst
@@ -299,9 +299,9 @@ The commonly used options are:
     .. code-block::
 
       $ cd egs/librispeech/ASR
-      $ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --lattice-score-scale 0.5
+      $ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --nbest-scale 0.5
 
-  - ``--lattice-score-scale``
+  - ``--nbest-scale``
 
     It is used to scale down lattice scores so that there are more unique
     paths for rescoring.
@@ -577,7 +577,7 @@ The command to run HLG decoding + LM rescoring + attention decoder rescoring is:
        --G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \
        --ngram-lm-scale 1.3 \
        --attention-decoder-scale 1.2 \
-       --lattice-score-scale 0.5 \
+       --nbest-scale 0.5 \
        --num-paths 100 \
        --sos-id 1 \
        --eos-id 1 \
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index d04e912bf..43a46a30f 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -40,7 +40,7 @@ python conformer_ctc/train.py --bucketing-sampler True \
                               --full-libri True \
                               --world-size 4
 
-python conformer_ctc/decode.py --lattice-score-scale 0.5 \
+python conformer_ctc/decode.py --nbest-scale 0.5 \
                                --epoch 34 \
                                --avg 20 \
                                --method attention-decoder \
diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index b5b41c82e..5a83dd39c 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -23,6 +23,7 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
 import k2
+import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
@@ -77,6 +78,9 @@ def get_parser():
         default="attention-decoder",
         help="""Decoding method.
         Supported values are:
+            - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
+              model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+              It needs neither a lexicon nor an n-gram LM.
            - (1) 1best. Extract the best path from the decoding lattice
              as the decoding result.
            - (2) nbest. Extract n paths from the decoding lattice; the path
@@ -106,7 +110,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lattice-score-scale",
+        "--nbest-scale",
         type=float,
         default=0.5,
         help="""The scale to be applied to `lattice.scores`.
@@ -128,14 +132,26 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe",
+        help="The lang dir",
+    )
+
     return parser
 
 
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc/exp"),
-            "lang_dir": Path("data/lang_bpe"),
             "lm_dir": Path("data/lm"),
             # parameters for conformer
             "subsampling_factor": 4,
@@ -159,13 +175,15 @@ def get_params() -> AttributeDict:
 
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    HLG: k2.Fsa,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
     batch: dict,
     word_table: k2.SymbolTable,
     sos_id: int,
     eos_id: int,
     G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict.
     The dict has the following format:
@@ -190,7 +208,11 @@ def decode_one_batch(
       model:
         The neural model.
       HLG:
-        The decoding graph.
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
       batch:
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
@@ -209,7 +231,10 @@ def decode_one_batch(
       Return the decoding result. See above description for the format of
       the returned dict.
     """
-    device = HLG.device
+    if HLG is not None:
+        device = HLG.device
+    else:
+        device = H.device
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
@@ -229,9 +254,17 @@ def decode_one_batch(
         1,
     ).to(torch.int32)
 
+    if H is None:
+        assert HLG is not None
+        decoding_graph = HLG
+    else:
+        assert HLG is None
+        assert bpe_model is not None
+        decoding_graph = H
+
     lattice = get_lattice(
         nnet_output=nnet_output,
-        HLG=HLG,
+        decoding_graph=decoding_graph,
         supervision_segments=supervision_segments,
         search_beam=params.search_beam,
         output_beam=params.output_beam,
@@ -240,6 +273,24 @@ def decode_one_batch(
         subsampling_factor=params.subsampling_factor,
     )
 
+    if params.method == "ctc-decoding":
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        # Note: `best_path.aux_labels` contains token IDs, not word IDs
+        # since we are using H, not HLG here.
+        #
+        # token_ids is a list-of-lists of IDs
+        token_ids = get_texts(best_path)
+
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of lists of str, e.g., [['xxx', 'yyy', 'zzz'], ...]
+        hyps = [s.split() for s in hyps]
+        key = "ctc-decoding"
+        return {key: hyps}
+
     if params.method == "nbest-oracle":
         # Note: You can also pass rescored lattices to it.
        # We choose the HLG decoded lattice for speed reasons
@@ -250,12 +301,12 @@ def decode_one_batch(
             num_paths=params.num_paths,
             ref_texts=supervisions["text"],
             word_table=word_table,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
             oov="<UNK>",
         )
         hyps = get_texts(best_path)
         hyps = [[word_table[i] for i in ids] for ids in hyps]
-        key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}"  # noqa
+        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
         return {key: hyps}
 
     if params.method in ["1best", "nbest"]:
@@ -269,9 +320,9 @@ def decode_one_batch(
             lattice=lattice,
             num_paths=params.num_paths,
             use_double_scores=params.use_double_scores,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
         )
-        key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}"  # noqa
+        key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
 
         hyps = get_texts(best_path)
         hyps = [[word_table[i] for i in ids] for ids in hyps]
@@ -293,7 +344,7 @@ def decode_one_batch(
             G=G,
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
         )
     elif params.method == "whole-lattice-rescoring":
         best_path_dict = rescore_with_whole_lattice(
@@ -319,7 +370,7 @@ def decode_one_batch(
             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=sos_id,
             eos_id=eos_id,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
         )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
@@ -340,12 +391,14 @@ def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
-    HLG: k2.Fsa,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
     word_table: k2.SymbolTable,
     sos_id: int,
     eos_id: int,
     G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[Tuple[List[int], List[int]]]]:
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
 
     Args:
@@ -356,7 +409,11 @@ def decode_dataset(
       model:
         The neural model.
       HLG:
-        The decoding graph.
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
       word_table:
         It is the word symbol table.
      sos_id:
@@ -391,6 +448,8 @@ def decode_dataset(
                 params=params,
                 model=model,
                 HLG=HLG,
+                H=H,
+                bpe_model=bpe_model,
                 batch=batch,
                 word_table=word_table,
                 G=G,
@@ -469,6 +528,8 @@ def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
 
     params = get_params()
     params.update(vars(args))
@@ -496,14 +557,26 @@ def main():
     sos_id = graph_compiler.sos_id
     eos_id = graph_compiler.eos_id
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
-    HLG = HLG.to(device)
-    assert HLG.requires_grad is False
+    if params.method == "ctc-decoding":
+        HLG = None
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(str(params.lang_dir / "bpe.model"))
+    else:
+        H = None
+        bpe_model = None
+        HLG = k2.Fsa.from_dict(
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+        )
+        HLG = HLG.to(device)
+        assert HLG.requires_grad is False
 
-    if not hasattr(HLG, "lm_scores"):
-        HLG.lm_scores = HLG.scores.clone()
+        if not hasattr(HLG, "lm_scores"):
+            HLG.lm_scores = HLG.scores.clone()
 
     if params.method in (
         "nbest-rescoring",
@@ -593,6 +666,8 @@ def main():
         params=params,
         model=model,
         HLG=HLG,
+        H=H,
+        bpe_model=bpe_model,
         word_table=lexicon.word_table,
         G=G,
         sos_id=sos_id,
diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py
index c924b87bb..00812d674 100755
--- a/egs/librispeech/ASR/conformer_ctc/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -125,7 +125,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lattice-score-scale",
+        "--nbest-scale",
         type=float,
         default=0.5,
         help="""
@@ -301,7 +301,7 @@ def main():
 
     lattice = get_lattice(
         nnet_output=nnet_output,
-        HLG=HLG,
+        decoding_graph=HLG,
         supervision_segments=supervision_segments,
         search_beam=params.search_beam,
         output_beam=params.output_beam,
@@ -336,7 +336,7 @@ def main():
             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=params.sos_id,
             eos_id=params.eos_id,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
             ngram_lm_scale=params.ngram_lm_scale,
             attention_scale=params.attention_decoder_scale,
         )
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 1e91b1008..54c2f7a6b 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -97,7 +97,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lattice-score-scale",
+        "--nbest-scale",
         type=float,
         default=0.5,
         help="""The scale to be applied to `lattice.scores`.
@@ -146,7 +146,7 @@ def decode_one_batch(
     batch: dict,
     lexicon: Lexicon,
     G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict.
     The dict has the following format:
 
@@ -210,7 +210,7 @@ def decode_one_batch(
 
     lattice = get_lattice(
         nnet_output=nnet_output,
-        HLG=HLG,
+        decoding_graph=HLG,
         supervision_segments=supervision_segments,
         search_beam=params.search_beam,
         output_beam=params.output_beam,
@@ -229,7 +229,7 @@ def decode_one_batch(
             lattice=lattice,
             num_paths=params.num_paths,
             use_double_scores=params.use_double_scores,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
         )
         key = f"no_rescore-{params.num_paths}"
         hyps = get_texts(best_path)
@@ -248,7 +248,7 @@ def decode_one_batch(
             G=G,
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
         )
     else:
         best_path_dict = rescore_with_whole_lattice(
@@ -272,7 +272,7 @@ def decode_dataset(
     HLG: k2.Fsa,
     lexicon: Lexicon,
     G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[Tuple[List[int], List[int]]]]:
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
 
     Args:
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
index 0a543d859..2baeb6bba 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
@@ -232,7 +232,7 @@ def main():
 
     lattice = get_lattice(
         nnet_output=nnet_output,
-        HLG=HLG,
+        decoding_graph=HLG,
         supervision_segments=supervision_segments,
         search_beam=params.search_beam,
         output_beam=params.output_beam,
diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py
index 325acf316..57122235a 100755
--- a/egs/yesno/ASR/tdnn/decode.py
+++ b/egs/yesno/ASR/tdnn/decode.py
@@ -124,7 +124,7 @@ def decode_one_batch(
 
     lattice = get_lattice(
         nnet_output=nnet_output,
-        HLG=HLG,
+        decoding_graph=HLG,
         supervision_segments=supervision_segments,
         search_beam=params.search_beam,
         output_beam=params.output_beam,
diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py
index fb92110e3..14220be19 100755
--- a/egs/yesno/ASR/tdnn/pretrained.py
+++ b/egs/yesno/ASR/tdnn/pretrained.py
@@ -175,7 +175,7 @@ def main():
 
     lattice = get_lattice(
         nnet_output=nnet_output,
-        HLG=HLG,
+        decoding_graph=HLG,
         supervision_segments=supervision_segments,
         search_beam=params.search_beam,
         output_beam=params.output_beam,
diff --git a/icefall/decode.py b/icefall/decode.py
index e678e4622..62d27dd68 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -66,7 +66,7 @@ def _intersect_device(
 
 def get_lattice(
     nnet_output: torch.Tensor,
-    HLG: k2.Fsa,
+    decoding_graph: k2.Fsa,
     supervision_segments: torch.Tensor,
     search_beam: float,
     output_beam: float,
@@ -79,8 +79,9 @@ def get_lattice(
     Args:
       nnet_output:
        It is the output of a neural model of shape `(N, T, C)`.
-      HLG:
-       An Fsa, the decoding graph. See also `compile_HLG.py`.
+      decoding_graph:
+       An Fsa, the decoding graph. It can be either an HLG
+       (see `compile_HLG.py`) or an H (see `k2.ctc_topo`).
      supervision_segments:
        A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns.
        Each row contains information for a supervision segment. Column 0
@@ -117,7 +118,7 @@ def get_lattice(
     )
 
     lattice = k2.intersect_dense_pruned(
-        HLG,
+        decoding_graph,
         dense_fsa_vec,
         search_beam=search_beam,
         output_beam=output_beam,
@@ -180,7 +181,7 @@ class Nbest(object):
         lattice: k2.Fsa,
         num_paths: int,
         use_double_scores: bool = True,
-        lattice_score_scale: float = 0.5,
+        nbest_scale: float = 0.5,
     ) -> "Nbest":
         """Construct an Nbest object by **sampling** `num_paths` from a lattice.
 
@@ -206,7 +207,7 @@ class Nbest(object):
           Return an Nbest instance.
         """
         saved_scores = lattice.scores.clone()
-        lattice.scores *= lattice_score_scale
+        lattice.scores *= nbest_scale
         # path is a ragged tensor with dtype torch.int32.
         # It has three axes [utt][path][arc_pos]
         path = k2.random_paths(
@@ -446,7 +447,7 @@ def nbest_decoding(
     lattice: k2.Fsa,
     num_paths: int,
     use_double_scores: bool = True,
-    lattice_score_scale: float = 1.0,
+    nbest_scale: float = 1.0,
 ) -> k2.Fsa:
     """It implements something like CTC prefix beam search using n-best lists.
 
@@ -474,7 +475,7 @@ def nbest_decoding(
      use_double_scores:
        True to use double precision floating point in the computation.
        False to use single precision.
-      lattice_score_scale:
+      nbest_scale:
        It's the scale applied to the `lattice.scores`. A smaller value
        leads to more unique paths at the risk of missing the correct path.
     Returns:
@@ -484,7 +485,7 @@ def nbest_decoding(
         lattice=lattice,
         num_paths=num_paths,
         use_double_scores=use_double_scores,
-        lattice_score_scale=lattice_score_scale,
+        nbest_scale=nbest_scale,
     )
 
     # nbest.fsa.scores contains 0s
@@ -505,7 +506,7 @@ def nbest_oracle(
     ref_texts: List[str],
     word_table: k2.SymbolTable,
     use_double_scores: bool = True,
-    lattice_score_scale: float = 0.5,
+    nbest_scale: float = 0.5,
     oov: str = "<UNK>",
 ) -> Dict[str, List[List[int]]]:
     """Select the best hypothesis given a lattice and a reference transcript.
@@ -517,7 +518,7 @@ def nbest_oracle(
     The decoding result returned from this function is the best result that
     we can obtain using n-best decoding with all kinds of rescoring
     techniques.
-    This function is useful to tune the value of `lattice_score_scale`.
+    This function is useful to tune the value of `nbest_scale`.
 
     Args:
       lattice:
@@ -533,7 +534,7 @@ def nbest_oracle(
      use_double_scores:
        True to use double precision for computation. False to use
        single precision.
-      lattice_score_scale:
+      nbest_scale:
        It's the scale applied to the lattice.scores. A smaller value
        yields more unique paths.
      oov:
@@ -549,7 +550,7 @@ def nbest_oracle(
         lattice=lattice,
         num_paths=num_paths,
         use_double_scores=use_double_scores,
-        lattice_score_scale=lattice_score_scale,
+        nbest_scale=nbest_scale,
     )
 
     hyps = nbest.build_levenshtein_graphs()
@@ -590,7 +591,7 @@ def rescore_with_n_best_list(
     G: k2.Fsa,
     num_paths: int,
     lm_scale_list: List[float],
-    lattice_score_scale: float = 1.0,
+    nbest_scale: float = 1.0,
     use_double_scores: bool = True,
 ) -> Dict[str, k2.Fsa]:
     """Rescore an n-best list with an n-gram LM.
@@ -607,7 +608,7 @@ def rescore_with_n_best_list(
        Size of nbest list.
      lm_scale_list:
        A list of float representing LM score scales.
-      lattice_score_scale:
+      nbest_scale:
        Scale to be applied to ``lattice.score`` when sampling paths
        using ``k2.random_paths``.
      use_double_scores:
@@ -631,7 +632,7 @@ def rescore_with_n_best_list(
         lattice=lattice,
         num_paths=num_paths,
         use_double_scores=use_double_scores,
-        lattice_score_scale=lattice_score_scale,
+        nbest_scale=nbest_scale,
     )
 
     # nbest.fsa.scores are all 0s at this point
@@ -769,7 +770,7 @@ def rescore_with_attention_decoder(
     memory_key_padding_mask: Optional[torch.Tensor],
     sos_id: int,
     eos_id: int,
-    lattice_score_scale: float = 1.0,
+    nbest_scale: float = 1.0,
     ngram_lm_scale: Optional[float] = None,
     attention_scale: Optional[float] = None,
     use_double_scores: bool = True,
@@ -796,7 +797,7 @@ def rescore_with_attention_decoder(
        The token ID for SOS.
      eos_id:
        The token ID for EOS.
-      lattice_score_scale:
+      nbest_scale:
        It's the scale applied to `lattice.scores`.
       A smaller value leads to more unique paths at the risk of missing the correct path.
      ngram_lm_scale:
@@ -812,7 +813,7 @@ def rescore_with_attention_decoder(
         lattice=lattice,
         num_paths=num_paths,
         use_double_scores=use_double_scores,
-        lattice_score_scale=lattice_score_scale,
+        nbest_scale=nbest_scale,
     )
 
     # nbest.fsa.scores are all 0s at this point
diff --git a/test/test_decode.py b/test/test_decode.py
index 7ef127781..97964ac67 100644
--- a/test/test_decode.py
+++ b/test/test_decode.py
@@ -43,7 +43,7 @@ def test_nbest_from_lattice():
         lattice=lattice,
         num_paths=10,
         use_double_scores=True,
-        lattice_score_scale=0.5,
+        nbest_scale=0.5,
     )
     # each lattice has only 4 distinct paths that have different word sequences:
     # 10->30