From 707d7017a748f7e53818ad2bb52d1a3780d18d67 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 26 Sep 2021 14:21:49 +0800 Subject: [PATCH 1/5] Support pure ctc decoding requiring neither a lexicon nor an n-gram LM (#58) * Rename lattice_score_scale to nbest_scale. * Support pure CTC decoding requiring neither a lexicion nor an n-gram LM. * Fix style issues. * Fix a typo. * Minor fixes. --- .../recipes/librispeech/conformer_ctc.rst | 6 +- egs/librispeech/ASR/RESULTS.md | 2 +- egs/librispeech/ASR/conformer_ctc/decode.py | 123 ++++++++++++++---- .../ASR/conformer_ctc/pretrained.py | 6 +- egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 12 +- .../ASR/tdnn_lstm_ctc/pretrained.py | 2 +- egs/yesno/ASR/tdnn/decode.py | 2 +- egs/yesno/ASR/tdnn/pretrained.py | 2 +- icefall/decode.py | 39 +++--- test/test_decode.py | 2 +- 10 files changed, 136 insertions(+), 60 deletions(-) diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst index 40100bc5a..a8b0683f4 100644 --- a/docs/source/recipes/librispeech/conformer_ctc.rst +++ b/docs/source/recipes/librispeech/conformer_ctc.rst @@ -299,9 +299,9 @@ The commonly used options are: .. code-block:: $ cd egs/librispeech/ASR - $ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --lattice-score-scale 0.5 + $ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --nbest-scale 0.5 - - ``--lattice-score-scale`` + - ``--nbest-scale`` It is used to scale down lattice scores so that there are more unique paths for rescoring. @@ -577,7 +577,7 @@ The command to run HLG decoding + LM rescoring + attention decoder rescoring is: --G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \ --ngram-lm-scale 1.3 \ --attention-decoder-scale 1.2 \ - --lattice-score-scale 0.5 \ + --nbest-scale 0.5 \ --num-paths 100 \ --sos-id 1 \ --eos-id 1 \ diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index d04e912bf..43a46a30f 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -40,7 +40,7 @@ python conformer_ctc/train.py --bucketing-sampler True \ --full-libri True \ --world-size 4 -python conformer_ctc/decode.py --lattice-score-scale 0.5 \ +python conformer_ctc/decode.py --nbest-scale 0.5 \ --epoch 34 \ --avg 20 \ --method attention-decoder \ diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index b5b41c82e..5a83dd39c 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -23,6 +23,7 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple import k2 +import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule @@ -77,6 +78,9 @@ def get_parser(): default="attention-decoder", help="""Decoding method. Supported values are: + - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. - (1) 1best. Extract the best path from the decoding lattice as the decoding result. - (2) nbest. Extract n paths from the decoding lattice; the path @@ -106,7 +110,7 @@ def get_parser(): ) parser.add_argument( - "--lattice-score-scale", + "--nbest-scale", type=float, default=0.5, help="""The scale to be applied to `lattice.scores`. 
@@ -128,14 +132,26 @@ def get_parser(): """, ) + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe", + help="The lang dir", + ) + return parser def get_params() -> AttributeDict: params = AttributeDict( { - "exp_dir": Path("conformer_ctc/exp"), - "lang_dir": Path("data/lang_bpe"), "lm_dir": Path("data/lm"), # parameters for conformer "subsampling_factor": 4, @@ -159,13 +175,15 @@ def get_params() -> AttributeDict: def decode_one_batch( params: AttributeDict, model: nn.Module, - HLG: k2.Fsa, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], batch: dict, word_table: k2.SymbolTable, sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[int]]]: +) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -190,7 +208,11 @@ def decode_one_batch( model: The neural model. HLG: - The decoding graph. + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. batch: It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation @@ -209,7 +231,10 @@ def decode_one_batch( Return the decoding result. See above description for the format of the returned dict. """ - device = HLG.device + if HLG is not None: + device = HLG.device + else: + device = H.device feature = batch["inputs"] assert feature.ndim == 3 feature = feature.to(device) @@ -229,9 +254,17 @@ def decode_one_batch( 1, ).to(torch.int32) + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + lattice = get_lattice( nnet_output=nnet_output, - HLG=HLG, + decoding_graph=decoding_graph, supervision_segments=supervision_segments, search_beam=params.search_beam, output_beam=params.output_beam, @@ -240,6 +273,24 @@ def decode_one_batch( subsampling_factor=params.subsampling_factor, ) + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + if params.method == "nbest-oracle": # Note: You can also pass rescored lattices to it. 
# We choose the HLG decoded lattice for speed reasons @@ -250,12 +301,12 @@ def decode_one_batch( num_paths=params.num_paths, ref_texts=supervisions["text"], word_table=word_table, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, oov="", ) hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] - key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa return {key: hyps} if params.method in ["1best", "nbest"]: @@ -269,9 +320,9 @@ def decode_one_batch( lattice=lattice, num_paths=params.num_paths, use_double_scores=params.use_double_scores, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) - key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] @@ -293,7 +344,7 @@ def decode_one_batch( G=G, num_paths=params.num_paths, lm_scale_list=lm_scale_list, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) elif params.method == "whole-lattice-rescoring": best_path_dict = rescore_with_whole_lattice( @@ -319,7 +370,7 @@ def decode_one_batch( memory_key_padding_mask=memory_key_padding_mask, sos_id=sos_id, eos_id=eos_id, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) else: assert False, f"Unsupported decoding method: {params.method}" @@ -340,12 +391,14 @@ def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, - HLG: k2.Fsa, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], word_table: k2.SymbolTable, sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[int], List[int]]]]: +) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. Args: @@ -356,7 +409,11 @@ def decode_dataset( model: The neural model. HLG: - The decoding graph. + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. word_table: It is the word symbol table. 
sos_id: @@ -391,6 +448,8 @@ def decode_dataset( params=params, model=model, HLG=HLG, + H=H, + bpe_model=bpe_model, batch=batch, word_table=word_table, G=G, @@ -469,6 +528,8 @@ def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) params = get_params() params.update(vars(args)) @@ -496,14 +557,26 @@ def main(): sos_id = graph_compiler.sos_id eos_id = graph_compiler.eos_id - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") - ) - HLG = HLG.to(device) - assert HLG.requires_grad is False + if params.method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") + ) + HLG = HLG.to(device) + assert HLG.requires_grad is False - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() if params.method in ( "nbest-rescoring", @@ -593,6 +666,8 @@ def main(): params=params, model=model, HLG=HLG, + H=H, + bpe_model=bpe_model, word_table=lexicon.word_table, G=G, sos_id=sos_id, diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index c924b87bb..00812d674 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -125,7 +125,7 @@ def get_parser(): ) parser.add_argument( - "--lattice-score-scale", + "--nbest-scale", type=float, default=0.5, help=""" @@ -301,7 +301,7 @@ def main(): lattice = get_lattice( nnet_output=nnet_output, - HLG=HLG, + decoding_graph=HLG, supervision_segments=supervision_segments, search_beam=params.search_beam, output_beam=params.output_beam, @@ -336,7 +336,7 @@ def main(): memory_key_padding_mask=memory_key_padding_mask, sos_id=params.sos_id, eos_id=params.eos_id, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ngram_lm_scale=params.ngram_lm_scale, attention_scale=params.attention_decoder_scale, ) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 1e91b1008..54c2f7a6b 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -97,7 +97,7 @@ def get_parser(): ) parser.add_argument( - "--lattice-score-scale", + "--nbest-scale", type=float, default=0.5, help="""The scale to be applied to `lattice.scores`. @@ -146,7 +146,7 @@ def decode_one_batch( batch: dict, lexicon: Lexicon, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[int]]]: +) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. 
The dict has the following format: @@ -210,7 +210,7 @@ def decode_one_batch( lattice = get_lattice( nnet_output=nnet_output, - HLG=HLG, + decoding_graph=HLG, supervision_segments=supervision_segments, search_beam=params.search_beam, output_beam=params.output_beam, @@ -229,7 +229,7 @@ def decode_one_batch( lattice=lattice, num_paths=params.num_paths, use_double_scores=params.use_double_scores, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) key = f"no_rescore-{params.num_paths}" hyps = get_texts(best_path) @@ -248,7 +248,7 @@ def decode_one_batch( G=G, num_paths=params.num_paths, lm_scale_list=lm_scale_list, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) else: best_path_dict = rescore_with_whole_lattice( @@ -272,7 +272,7 @@ def decode_dataset( HLG: k2.Fsa, lexicon: Lexicon, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[int], List[int]]]]: +) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. Args: diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py index 0a543d859..2baeb6bba 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py @@ -232,7 +232,7 @@ def main(): lattice = get_lattice( nnet_output=nnet_output, - HLG=HLG, + decoding_graph=HLG, supervision_segments=supervision_segments, search_beam=params.search_beam, output_beam=params.output_beam, diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py index 325acf316..57122235a 100755 --- a/egs/yesno/ASR/tdnn/decode.py +++ b/egs/yesno/ASR/tdnn/decode.py @@ -124,7 +124,7 @@ def decode_one_batch( lattice = get_lattice( nnet_output=nnet_output, - HLG=HLG, + decoding_graph=HLG, supervision_segments=supervision_segments, search_beam=params.search_beam, output_beam=params.output_beam, diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py index fb92110e3..14220be19 100755 --- a/egs/yesno/ASR/tdnn/pretrained.py +++ b/egs/yesno/ASR/tdnn/pretrained.py @@ -175,7 +175,7 @@ def main(): lattice = get_lattice( nnet_output=nnet_output, - HLG=HLG, + decoding_graph=HLG, supervision_segments=supervision_segments, search_beam=params.search_beam, output_beam=params.output_beam, diff --git a/icefall/decode.py b/icefall/decode.py index e678e4622..62d27dd68 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -66,7 +66,7 @@ def _intersect_device( def get_lattice( nnet_output: torch.Tensor, - HLG: k2.Fsa, + decoding_graph: k2.Fsa, supervision_segments: torch.Tensor, search_beam: float, output_beam: float, @@ -79,8 +79,9 @@ def get_lattice( Args: nnet_output: It is the output of a neural model of shape `(N, T, C)`. - HLG: - An Fsa, the decoding graph. See also `compile_HLG.py`. + decoding_graph: + An Fsa, the decoding graph. It can be either an HLG + (see `compile_HLG.py`) or an H (see `k2.ctc_topo`). supervision_segments: A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns. Each row contains information for a supervision segment. Column 0 @@ -117,7 +118,7 @@ def get_lattice( ) lattice = k2.intersect_dense_pruned( - HLG, + decoding_graph, dense_fsa_vec, search_beam=search_beam, output_beam=output_beam, @@ -180,7 +181,7 @@ class Nbest(object): lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True, - lattice_score_scale: float = 0.5, + nbest_scale: float = 0.5, ) -> "Nbest": """Construct an Nbest object by **sampling** `num_paths` from a lattice. 
@@ -206,7 +207,7 @@ class Nbest(object): Return an Nbest instance. """ saved_scores = lattice.scores.clone() - lattice.scores *= lattice_score_scale + lattice.scores *= nbest_scale # path is a ragged tensor with dtype torch.int32. # It has three axes [utt][path][arc_pos] path = k2.random_paths( @@ -446,7 +447,7 @@ def nbest_decoding( lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True, - lattice_score_scale: float = 1.0, + nbest_scale: float = 1.0, ) -> k2.Fsa: """It implements something like CTC prefix beam search using n-best lists. @@ -474,7 +475,7 @@ def nbest_decoding( use_double_scores: True to use double precision floating point in the computation. False to use single precision. - lattice_score_scale: + nbest_scale: It's the scale applied to the `lattice.scores`. A smaller value leads to more unique paths at the risk of missing the correct path. Returns: @@ -484,7 +485,7 @@ def nbest_decoding( lattice=lattice, num_paths=num_paths, use_double_scores=use_double_scores, - lattice_score_scale=lattice_score_scale, + nbest_scale=nbest_scale, ) # nbest.fsa.scores contains 0s @@ -505,7 +506,7 @@ def nbest_oracle( ref_texts: List[str], word_table: k2.SymbolTable, use_double_scores: bool = True, - lattice_score_scale: float = 0.5, + nbest_scale: float = 0.5, oov: str = "", ) -> Dict[str, List[List[int]]]: """Select the best hypothesis given a lattice and a reference transcript. @@ -517,7 +518,7 @@ def nbest_oracle( The decoding result returned from this function is the best result that we can obtain using n-best decoding with all kinds of rescoring techniques. - This function is useful to tune the value of `lattice_score_scale`. + This function is useful to tune the value of `nbest_scale`. Args: lattice: @@ -533,7 +534,7 @@ def nbest_oracle( use_double_scores: True to use double precision for computation. False to use single precision. - lattice_score_scale: + nbest_scale: It's the scale applied to the lattice.scores. A smaller value yields more unique paths. oov: @@ -549,7 +550,7 @@ def nbest_oracle( lattice=lattice, num_paths=num_paths, use_double_scores=use_double_scores, - lattice_score_scale=lattice_score_scale, + nbest_scale=nbest_scale, ) hyps = nbest.build_levenshtein_graphs() @@ -590,7 +591,7 @@ def rescore_with_n_best_list( G: k2.Fsa, num_paths: int, lm_scale_list: List[float], - lattice_score_scale: float = 1.0, + nbest_scale: float = 1.0, use_double_scores: bool = True, ) -> Dict[str, k2.Fsa]: """Rescore an n-best list with an n-gram LM. @@ -607,7 +608,7 @@ def rescore_with_n_best_list( Size of nbest list. lm_scale_list: A list of float representing LM score scales. - lattice_score_scale: + nbest_scale: Scale to be applied to ``lattice.score`` when sampling paths using ``k2.random_paths``. use_double_scores: @@ -631,7 +632,7 @@ def rescore_with_n_best_list( lattice=lattice, num_paths=num_paths, use_double_scores=use_double_scores, - lattice_score_scale=lattice_score_scale, + nbest_scale=nbest_scale, ) # nbest.fsa.scores are all 0s at this point @@ -769,7 +770,7 @@ def rescore_with_attention_decoder( memory_key_padding_mask: Optional[torch.Tensor], sos_id: int, eos_id: int, - lattice_score_scale: float = 1.0, + nbest_scale: float = 1.0, ngram_lm_scale: Optional[float] = None, attention_scale: Optional[float] = None, use_double_scores: bool = True, @@ -796,7 +797,7 @@ def rescore_with_attention_decoder( The token ID for SOS. eos_id: The token ID for EOS. - lattice_score_scale: + nbest_scale: It's the scale applied to `lattice.scores`. 
A smaller value leads to more unique paths at the risk of missing the correct path. ngram_lm_scale: @@ -812,7 +813,7 @@ def rescore_with_attention_decoder( lattice=lattice, num_paths=num_paths, use_double_scores=use_double_scores, - lattice_score_scale=lattice_score_scale, + nbest_scale=nbest_scale, ) # nbest.fsa.scores are all 0s at this point diff --git a/test/test_decode.py b/test/test_decode.py index 7ef127781..97964ac67 100644 --- a/test/test_decode.py +++ b/test/test_decode.py @@ -43,7 +43,7 @@ def test_nbest_from_lattice(): lattice=lattice, num_paths=10, use_double_scores=True, - lattice_score_scale=0.5, + nbest_scale=0.5, ) # each lattice has only 4 distinct paths that have different word sequences: # 10->30 From adb068eb8242fe79dafce5a100c3fdfad934c7a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 1 Oct 2021 04:43:08 -0400 Subject: [PATCH 2/5] setup.py (#64) --- .gitignore | 1 + setup.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 setup.py diff --git a/.gitignore b/.gitignore index e6c84ca5e..f4f703243 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +icefall.egg-info/ data __pycache__ path.sh diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..6c720e121 --- /dev/null +++ b/setup.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 + +from setuptools import find_packages, setup +from pathlib import Path + +icefall_dir = Path(__file__).parent +install_requires = (icefall_dir / "requirements.txt").read_text().splitlines() + +setup( + name="icefall", + version="1.0", + python_requires=">=3.6.0", + description="Speech processing recipes using k2 and Lhotse.", + author="The k2 and Lhotse Development Team", + license="Apache-2.0 License", + packages=find_packages(), + install_requires=install_requires, + classifiers=[ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Intended Audience :: Science/Research", + "Operating System :: POSIX :: Linux", + "License :: OSI Approved :: Apache Software License", + "Topic :: Multimedia :: Sound/Audio :: Speech", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", + ], +) From b682467e4d5b3cc085d7fdb988f61e98979ae6a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 8 Oct 2021 22:32:13 -0400 Subject: [PATCH 3/5] Use BucketingSampler for dev and test data --- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 16 ++++++++-------- egs/yesno/ASR/tdnn/asr_datamodule.py | 17 ++++++++--------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 8290e71d1..4953e8538 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -21,6 +21,10 @@ from functools import lru_cache from pathlib import Path from typing import List, Union +from torch.utils.data import DataLoader + +from icefall.dataset.datamodule import DataModule +from icefall.utils import str2bool from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( BucketingSampler, @@ -32,10 +36,6 @@ from lhotse.dataset import ( SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from 
icefall.dataset.datamodule import DataModule -from icefall.utils import str2bool class LibriSpeechAsrDataModule(DataModule): @@ -267,7 +267,7 @@ class LibriSpeechAsrDataModule(DataModule): cut_transforms=transforms, return_cuts=self.args.return_cuts, ) - valid_sampler = SingleCutSampler( + valid_sampler = BucketingSampler( cuts_valid, max_duration=self.args.max_duration, shuffle=False, @@ -300,12 +300,12 @@ class LibriSpeechAsrDataModule(DataModule): else PrecomputedFeatures(), return_cuts=self.args.return_cuts, ) - sampler = SingleCutSampler( - cuts_test, max_duration=self.args.max_duration + sampler = BucketingSampler( + cuts_test, max_duration=self.args.max_duration, shuffle=False ) logging.debug("About to create test dataloader") test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=1 + test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers ) test_loaders.append(test_dl) diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py index e6614e3ce..a9a6145f0 100644 --- a/egs/yesno/ASR/tdnn/asr_datamodule.py +++ b/egs/yesno/ASR/tdnn/asr_datamodule.py @@ -20,19 +20,18 @@ from functools import lru_cache from pathlib import Path from typing import List +from torch.utils.data import DataLoader + +from icefall.dataset.datamodule import DataModule +from icefall.utils import str2bool from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( BucketingSampler, CutConcatenate, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.dataset.datamodule import DataModule -from icefall.utils import str2bool class YesNoAsrDataModule(DataModule): @@ -198,7 +197,7 @@ class YesNoAsrDataModule(DataModule): ) else: logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + train_sampler = BucketingSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, @@ -226,12 +225,12 @@ class YesNoAsrDataModule(DataModule): else PrecomputedFeatures(), return_cuts=self.args.return_cuts, ) - sampler = SingleCutSampler( - cuts_test, max_duration=self.args.max_duration + sampler = BucketingSampler( + cuts_test, max_duration=self.args.max_duration, shuffle=False ) logging.debug("About to create test dataloader") test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=1 + test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers ) return test_dl From 6e43905d124566dc0779e9c900b0a309c21850f0 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Sat, 9 Oct 2021 11:56:25 +0800 Subject: [PATCH 4/5] Update the documentation to include "ctc-decoding" (#71) * Update conformer_ctc.rst --- .../recipes/librispeech/conformer_ctc.rst | 68 ++++++++++++++++++- 1 file changed, 66 insertions(+), 2 deletions(-) diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst index a8b0683f4..73c5503d8 100644 --- a/docs/source/recipes/librispeech/conformer_ctc.rst +++ b/docs/source/recipes/librispeech/conformer_ctc.rst @@ -292,9 +292,18 @@ The commonly used options are: - ``--method`` - This specifies the decoding method. + This specifies the decoding method. This script supports 7 decoding methods. + As for ctc decoding, it uses a sentence piece model to convert word pieces to words. + And it needs neither a lexicon nor an n-gram LM. 
+ + For example, the following command uses CTC topology for decoding: + + .. code-block:: - The following command uses attention decoder for rescoring: + $ cd egs/librispeech/ASR + $ ./conformer_ctc/decode.py --method ctc-decoding --max-duration 300 + + And the following command uses attention decoder for rescoring: .. code-block:: @@ -311,6 +320,61 @@ The commonly used options are: It has the same meaning as the one during training. A larger value may cause OOM. +Here are some results for CTC decoding with a vocab size of 500: + +Usage: + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./conformer_ctc/decode.py \ + --epoch 25 \ + --avg 1 \ + --max-duration 300 \ + --exp-dir conformer_ctc/exp \ + --lang-dir data/lang_bpe_500 \ + --method ctc-decoding + +The output is given below: + +.. code-block:: bash + + 2021-09-26 12:44:31,033 INFO [decode.py:537] Decoding started + 2021-09-26 12:44:31,033 INFO [decode.py:538] + {'lm_dir': PosixPath('data/lm'), 'subsampling_factor': 4, 'vgg_frontend': False, 'use_feat_batchnorm': True, + 'feature_dim': 80, 'nhead': 8, 'attention_dim': 512, 'num_decoder_layers': 6, 'search_beam': 20, 'output_beam': 8, + 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, + 'epoch': 25, 'avg': 1, 'method': 'ctc-decoding', 'num_paths': 100, 'nbest_scale': 0.5, + 'export': False, 'exp_dir': PosixPath('conformer_ctc/exp'), 'lang_dir': PosixPath('data/lang_bpe_500'), 'full_libri': False, + 'feature_dir': PosixPath('data/fbank'), 'max_duration': 100, 'bucketing_sampler': False, 'num_buckets': 30, + 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, + 'shuffle': True, 'return_cuts': True, 'num_workers': 2} + 2021-09-26 12:44:31,406 INFO [lexicon.py:113] Loading pre-compiled data/lang_bpe_500/Linv.pt + 2021-09-26 12:44:31,464 INFO [decode.py:548] device: cuda:0 + 2021-09-26 12:44:36,171 INFO [checkpoint.py:92] Loading checkpoint from conformer_ctc/exp/epoch-25.pt + 2021-09-26 12:44:36,776 INFO [decode.py:652] Number of model parameters: 109226120 + 2021-09-26 12:44:37,714 INFO [decode.py:473] batch 0/206, cuts processed until now is 12 + 2021-09-26 12:45:15,944 INFO [decode.py:473] batch 100/206, cuts processed until now is 1328 + 2021-09-26 12:45:54,443 INFO [decode.py:473] batch 200/206, cuts processed until now is 2563 + 2021-09-26 12:45:56,411 INFO [decode.py:494] The transcripts are stored in conformer_ctc/exp/recogs-test-clean-ctc-decoding.txt + 2021-09-26 12:45:56,592 INFO [utils.py:331] [test-clean-ctc-decoding] %WER 3.26% [1715 / 52576, 163 ins, 128 del, 1424 sub ] + 2021-09-26 12:45:56,807 INFO [decode.py:506] Wrote detailed error stats to conformer_ctc/exp/errs-test-clean-ctc-decoding.txt + 2021-09-26 12:45:56,808 INFO [decode.py:522] + For test-clean, WER of different settings are: + ctc-decoding 3.26 best for test-clean + + 2021-09-26 12:45:57,362 INFO [decode.py:473] batch 0/203, cuts processed until now is 15 + 2021-09-26 12:46:35,565 INFO [decode.py:473] batch 100/203, cuts processed until now is 1477 + 2021-09-26 12:47:15,106 INFO [decode.py:473] batch 200/203, cuts processed until now is 2922 + 2021-09-26 12:47:16,131 INFO [decode.py:494] The transcripts are stored in conformer_ctc/exp/recogs-test-other-ctc-decoding.txt + 2021-09-26 12:47:16,208 INFO [utils.py:331] [test-other-ctc-decoding] %WER 8.21% [4295 / 52343, 396 ins, 315 del, 3584 sub ] + 2021-09-26 12:47:16,432 INFO [decode.py:506] Wrote detailed error stats to conformer_ctc/exp/errs-test-other-ctc-decoding.txt + 2021-09-26 
12:47:16,432 INFO [decode.py:522] + For test-other, WER of different settings are: + ctc-decoding 8.21 best for test-other + + 2021-09-26 12:47:16,433 INFO [decode.py:680] Done! + Pre-trained Model ----------------- From 069ebaf9bab80209359b5ed19a3756cde2a551f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sat, 9 Oct 2021 14:45:46 +0000 Subject: [PATCH 5/5] Reformatting --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 5 ++++- egs/yesno/ASR/tdnn/asr_datamodule.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 4953e8538..229575db6 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -305,7 +305,10 @@ class LibriSpeechAsrDataModule(DataModule): ) logging.debug("About to create test dataloader") test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, ) test_loaders.append(test_dl) diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py index a9a6145f0..832fd556e 100644 --- a/egs/yesno/ASR/tdnn/asr_datamodule.py +++ b/egs/yesno/ASR/tdnn/asr_datamodule.py @@ -230,7 +230,10 @@ class YesNoAsrDataModule(DataModule): ) logging.debug("About to create test dataloader") test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, ) return test_dl
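
Most of this series is mechanical (the lattice_score_scale -> nbest_scale rename, the HLG -> decoding_graph parameter, and the sampler changes), but the new ctc-decoding path added in patch 1/5 is easier to follow as a condensed, linear sketch than as a diff. The snippet below restates the `params.method == "ctc-decoding"` branch of `decode_one_batch` outside the diff context. It is an illustration of the flow, not a drop-in script: `model`, `feature`, `supervisions`, `supervision_segments`, and `max_token_id` are placeholders for a trained conformer CTC model, one batch from the data module, and the largest BPE token ID, and the beam/state values simply mirror the defaults shown in the decode.py log above::

    import k2
    import sentencepiece as spm
    import torch

    from icefall.decode import get_lattice, one_best_decoding
    from icefall.utils import get_texts

    device = torch.device("cuda", 0)

    # H is a plain CTC topology; no lexicon (L) or grammar (G) is composed in.
    H = k2.ctc_topo(max_token=max_token_id, modified=False, device=device)

    bpe_model = spm.SentencePieceProcessor()
    bpe_model.load("data/lang_bpe/bpe.model")  # --lang-dir default in decode.py

    # nnet_output: (N, T, C) log-probs from the conformer encoder + CTC head.
    nnet_output, _, _ = model(feature, supervisions)

    lattice = get_lattice(
        nnet_output=nnet_output,
        decoding_graph=H,  # was the `HLG` argument before this series
        supervision_segments=supervision_segments,
        search_beam=20,
        output_beam=8,
        min_active_states=30,
        max_active_states=10000,
        subsampling_factor=4,
    )

    best_path = one_best_decoding(lattice=lattice, use_double_scores=True)

    # With H (not HLG), aux_labels carry token IDs rather than word IDs,
    # so map them back to text with the BPE model instead of a word table.
    token_ids = get_texts(best_path)
    hyps = [s.split() for s in bpe_model.decode(token_ids)]

Because the aux labels of H are word-piece IDs, `get_texts` yields one ID list per utterance, `bpe_model.decode` turns each list into a string, and splitting on whitespace gives the word-level hypotheses that the rest of decode.py scores for WER; this is why the method needs neither a lexicon nor an n-gram LM.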