From 707d7017a748f7e53818ad2bb52d1a3780d18d67 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sun, 26 Sep 2021 14:21:49 +0800
Subject: [PATCH 01/14] Support pure ctc decoding requiring neither a lexicon
 nor an n-gram LM (#58)

* Rename lattice_score_scale to nbest_scale.

* Support pure CTC decoding requiring neither a lexicon nor an n-gram LM.

* Fix style issues.

* Fix a typo.

* Minor fixes.
---
 .../recipes/librispeech/conformer_ctc.rst     |   6 +-
 egs/librispeech/ASR/RESULTS.md                |   2 +-
 egs/librispeech/ASR/conformer_ctc/decode.py   | 123 ++++++++++++++----
 .../ASR/conformer_ctc/pretrained.py           |   6 +-
 egs/librispeech/ASR/tdnn_lstm_ctc/decode.py   |  12 +-
 .../ASR/tdnn_lstm_ctc/pretrained.py           |   2 +-
 egs/yesno/ASR/tdnn/decode.py                  |   2 +-
 egs/yesno/ASR/tdnn/pretrained.py              |   2 +-
 icefall/decode.py                             |  39 +++---
 test/test_decode.py                           |   2 +-
 10 files changed, 136 insertions(+), 60 deletions(-)

diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst
index 40100bc5a..a8b0683f4 100644
--- a/docs/source/recipes/librispeech/conformer_ctc.rst
+++ b/docs/source/recipes/librispeech/conformer_ctc.rst
@@ -299,9 +299,9 @@ The commonly used options are:
   .. code-block::

     $ cd egs/librispeech/ASR
-    $ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --lattice-score-scale 0.5
+    $ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --nbest-scale 0.5

-- ``--lattice-score-scale``
+- ``--nbest-scale``

   It is used to scale down lattice scores so that there are more unique
   paths for rescoring.
@@ -577,7 +577,7 @@ The command to run HLG decoding + LM rescoring + attention decoder rescoring is:
       --G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \
       --ngram-lm-scale 1.3 \
       --attention-decoder-scale 1.2 \
-      --lattice-score-scale 0.5 \
+      --nbest-scale 0.5 \
       --num-paths 100 \
       --sos-id 1 \
       --eos-id 1 \
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index d04e912bf..43a46a30f 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -40,7 +40,7 @@ python conformer_ctc/train.py --bucketing-sampler True \
                               --full-libri True \
                               --world-size 4

-python conformer_ctc/decode.py --lattice-score-scale 0.5 \
+python conformer_ctc/decode.py --nbest-scale 0.5 \
                                --epoch 34 \
                                --avg 20 \
                                --method attention-decoder \
diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index b5b41c82e..5a83dd39c 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -23,6 +23,7 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple

 import k2
+import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
@@ -77,6 +78,9 @@ def get_parser():
         default="attention-decoder",
         help="""Decoding method.
         Supported values are:
+        - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
+          model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+          It needs neither a lexicon nor an n-gram LM.
         - (1) 1best. Extract the best path from the decoding lattice as the
           decoding result.
         - (2) nbest. Extract n paths from the decoding lattice; the path
@@ -106,7 +110,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--lattice-score-scale",
+        "--nbest-scale",
         type=float,
         default=0.5,
         help="""The scale to be applied to `lattice.scores`.
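
A toy illustration of this scale (an aside; it assumes, loosely, that
`k2.random_paths` samples arcs in proportion to `exp(score)`, and the helper
below is hypothetical, not code from this patch):

    import math

    def pick_probs(scores, scale):
        # A smaller scale flattens the distribution over competing arcs,
        # so repeated sampling visits more unique paths.
        exps = [math.exp(s * scale) for s in scores]
        total = sum(exps)
        return [e / total for e in exps]

    print(pick_probs([-1.0, -3.0], scale=1.0))  # ~[0.88, 0.12]
    print(pick_probs([-1.0, -3.0], scale=0.5))  # ~[0.73, 0.27]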
@@ -128,14 +132,26 @@ def get_parser(): """, ) + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe", + help="The lang dir", + ) + return parser def get_params() -> AttributeDict: params = AttributeDict( { - "exp_dir": Path("conformer_ctc/exp"), - "lang_dir": Path("data/lang_bpe"), "lm_dir": Path("data/lm"), # parameters for conformer "subsampling_factor": 4, @@ -159,13 +175,15 @@ def get_params() -> AttributeDict: def decode_one_batch( params: AttributeDict, model: nn.Module, - HLG: k2.Fsa, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], batch: dict, word_table: k2.SymbolTable, sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[int]]]: +) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -190,7 +208,11 @@ def decode_one_batch( model: The neural model. HLG: - The decoding graph. + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. batch: It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation @@ -209,7 +231,10 @@ def decode_one_batch( Return the decoding result. See above description for the format of the returned dict. """ - device = HLG.device + if HLG is not None: + device = HLG.device + else: + device = H.device feature = batch["inputs"] assert feature.ndim == 3 feature = feature.to(device) @@ -229,9 +254,17 @@ def decode_one_batch( 1, ).to(torch.int32) + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + lattice = get_lattice( nnet_output=nnet_output, - HLG=HLG, + decoding_graph=decoding_graph, supervision_segments=supervision_segments, search_beam=params.search_beam, output_beam=params.output_beam, @@ -240,6 +273,24 @@ def decode_one_batch( subsampling_factor=params.subsampling_factor, ) + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + if params.method == "nbest-oracle": # Note: You can also pass rescored lattices to it. 
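        # An illustrative aside on the ctc-decoding branch above; the
        # bpe.model path is an assumption, not taken from this patch:
        #
        #     import sentencepiece as spm
        #     sp = spm.SentencePieceProcessor()
        #     sp.load("data/lang_bpe/bpe.model")
        #     token_ids = [[15, 9, 284]]         # list-of-list from get_texts()
        #     texts = sp.decode(token_ids)       # e.g. ["hello world"]
        #     hyps = [s.split() for s in texts]  # [["hello", "world"]]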
# We choose the HLG decoded lattice for speed reasons @@ -250,12 +301,12 @@ def decode_one_batch( num_paths=params.num_paths, ref_texts=supervisions["text"], word_table=word_table, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, oov="", ) hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] - key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa return {key: hyps} if params.method in ["1best", "nbest"]: @@ -269,9 +320,9 @@ def decode_one_batch( lattice=lattice, num_paths=params.num_paths, use_double_scores=params.use_double_scores, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) - key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] @@ -293,7 +344,7 @@ def decode_one_batch( G=G, num_paths=params.num_paths, lm_scale_list=lm_scale_list, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) elif params.method == "whole-lattice-rescoring": best_path_dict = rescore_with_whole_lattice( @@ -319,7 +370,7 @@ def decode_one_batch( memory_key_padding_mask=memory_key_padding_mask, sos_id=sos_id, eos_id=eos_id, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) else: assert False, f"Unsupported decoding method: {params.method}" @@ -340,12 +391,14 @@ def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, - HLG: k2.Fsa, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], word_table: k2.SymbolTable, sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[int], List[int]]]]: +) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. Args: @@ -356,7 +409,11 @@ def decode_dataset( model: The neural model. HLG: - The decoding graph. + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. word_table: It is the word symbol table. 
sos_id: @@ -391,6 +448,8 @@ def decode_dataset( params=params, model=model, HLG=HLG, + H=H, + bpe_model=bpe_model, batch=batch, word_table=word_table, G=G, @@ -469,6 +528,8 @@ def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) params = get_params() params.update(vars(args)) @@ -496,14 +557,26 @@ def main(): sos_id = graph_compiler.sos_id eos_id = graph_compiler.eos_id - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") - ) - HLG = HLG.to(device) - assert HLG.requires_grad is False + if params.method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") + ) + HLG = HLG.to(device) + assert HLG.requires_grad is False - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() if params.method in ( "nbest-rescoring", @@ -593,6 +666,8 @@ def main(): params=params, model=model, HLG=HLG, + H=H, + bpe_model=bpe_model, word_table=lexicon.word_table, G=G, sos_id=sos_id, diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index c924b87bb..00812d674 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -125,7 +125,7 @@ def get_parser(): ) parser.add_argument( - "--lattice-score-scale", + "--nbest-scale", type=float, default=0.5, help=""" @@ -301,7 +301,7 @@ def main(): lattice = get_lattice( nnet_output=nnet_output, - HLG=HLG, + decoding_graph=HLG, supervision_segments=supervision_segments, search_beam=params.search_beam, output_beam=params.output_beam, @@ -336,7 +336,7 @@ def main(): memory_key_padding_mask=memory_key_padding_mask, sos_id=params.sos_id, eos_id=params.eos_id, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ngram_lm_scale=params.ngram_lm_scale, attention_scale=params.attention_decoder_scale, ) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 1e91b1008..54c2f7a6b 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -97,7 +97,7 @@ def get_parser(): ) parser.add_argument( - "--lattice-score-scale", + "--nbest-scale", type=float, default=0.5, help="""The scale to be applied to `lattice.scores`. @@ -146,7 +146,7 @@ def decode_one_batch( batch: dict, lexicon: Lexicon, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[int]]]: +) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. 
The dict has the following format: @@ -210,7 +210,7 @@ def decode_one_batch( lattice = get_lattice( nnet_output=nnet_output, - HLG=HLG, + decoding_graph=HLG, supervision_segments=supervision_segments, search_beam=params.search_beam, output_beam=params.output_beam, @@ -229,7 +229,7 @@ def decode_one_batch( lattice=lattice, num_paths=params.num_paths, use_double_scores=params.use_double_scores, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) key = f"no_rescore-{params.num_paths}" hyps = get_texts(best_path) @@ -248,7 +248,7 @@ def decode_one_batch( G=G, num_paths=params.num_paths, lm_scale_list=lm_scale_list, - lattice_score_scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) else: best_path_dict = rescore_with_whole_lattice( @@ -272,7 +272,7 @@ def decode_dataset( HLG: k2.Fsa, lexicon: Lexicon, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[int], List[int]]]]: +) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. Args: diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py index 0a543d859..2baeb6bba 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py @@ -232,7 +232,7 @@ def main(): lattice = get_lattice( nnet_output=nnet_output, - HLG=HLG, + decoding_graph=HLG, supervision_segments=supervision_segments, search_beam=params.search_beam, output_beam=params.output_beam, diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py index 325acf316..57122235a 100755 --- a/egs/yesno/ASR/tdnn/decode.py +++ b/egs/yesno/ASR/tdnn/decode.py @@ -124,7 +124,7 @@ def decode_one_batch( lattice = get_lattice( nnet_output=nnet_output, - HLG=HLG, + decoding_graph=HLG, supervision_segments=supervision_segments, search_beam=params.search_beam, output_beam=params.output_beam, diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py index fb92110e3..14220be19 100755 --- a/egs/yesno/ASR/tdnn/pretrained.py +++ b/egs/yesno/ASR/tdnn/pretrained.py @@ -175,7 +175,7 @@ def main(): lattice = get_lattice( nnet_output=nnet_output, - HLG=HLG, + decoding_graph=HLG, supervision_segments=supervision_segments, search_beam=params.search_beam, output_beam=params.output_beam, diff --git a/icefall/decode.py b/icefall/decode.py index e678e4622..62d27dd68 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -66,7 +66,7 @@ def _intersect_device( def get_lattice( nnet_output: torch.Tensor, - HLG: k2.Fsa, + decoding_graph: k2.Fsa, supervision_segments: torch.Tensor, search_beam: float, output_beam: float, @@ -79,8 +79,9 @@ def get_lattice( Args: nnet_output: It is the output of a neural model of shape `(N, T, C)`. - HLG: - An Fsa, the decoding graph. See also `compile_HLG.py`. + decoding_graph: + An Fsa, the decoding graph. It can be either an HLG + (see `compile_HLG.py`) or an H (see `k2.ctc_topo`). supervision_segments: A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns. Each row contains information for a supervision segment. Column 0 @@ -117,7 +118,7 @@ def get_lattice( ) lattice = k2.intersect_dense_pruned( - HLG, + decoding_graph, dense_fsa_vec, search_beam=search_beam, output_beam=output_beam, @@ -180,7 +181,7 @@ class Nbest(object): lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True, - lattice_score_scale: float = 0.5, + nbest_scale: float = 0.5, ) -> "Nbest": """Construct an Nbest object by **sampling** `num_paths` from a lattice. 
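
A hedged usage sketch of this constructor (``lattice`` is assumed to come
from ``get_lattice``; duplicate token sequences are dropped internally, so
fewer than ``num_paths`` entries may survive):

    nbest = Nbest.from_lattice(
        lattice=lattice,
        num_paths=100,
        use_double_scores=True,
        nbest_scale=0.5,  # smaller -> more unique sampled paths
    )
    # nbest.fsa is an FsaVec holding one linear FSA per sampled path;
    # nbest.shape maps each path back to its source utterance.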
@@ -206,7 +207,7 @@ class Nbest(object): Return an Nbest instance. """ saved_scores = lattice.scores.clone() - lattice.scores *= lattice_score_scale + lattice.scores *= nbest_scale # path is a ragged tensor with dtype torch.int32. # It has three axes [utt][path][arc_pos] path = k2.random_paths( @@ -446,7 +447,7 @@ def nbest_decoding( lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True, - lattice_score_scale: float = 1.0, + nbest_scale: float = 1.0, ) -> k2.Fsa: """It implements something like CTC prefix beam search using n-best lists. @@ -474,7 +475,7 @@ def nbest_decoding( use_double_scores: True to use double precision floating point in the computation. False to use single precision. - lattice_score_scale: + nbest_scale: It's the scale applied to the `lattice.scores`. A smaller value leads to more unique paths at the risk of missing the correct path. Returns: @@ -484,7 +485,7 @@ def nbest_decoding( lattice=lattice, num_paths=num_paths, use_double_scores=use_double_scores, - lattice_score_scale=lattice_score_scale, + nbest_scale=nbest_scale, ) # nbest.fsa.scores contains 0s @@ -505,7 +506,7 @@ def nbest_oracle( ref_texts: List[str], word_table: k2.SymbolTable, use_double_scores: bool = True, - lattice_score_scale: float = 0.5, + nbest_scale: float = 0.5, oov: str = "", ) -> Dict[str, List[List[int]]]: """Select the best hypothesis given a lattice and a reference transcript. @@ -517,7 +518,7 @@ def nbest_oracle( The decoding result returned from this function is the best result that we can obtain using n-best decoding with all kinds of rescoring techniques. - This function is useful to tune the value of `lattice_score_scale`. + This function is useful to tune the value of `nbest_scale`. Args: lattice: @@ -533,7 +534,7 @@ def nbest_oracle( use_double_scores: True to use double precision for computation. False to use single precision. - lattice_score_scale: + nbest_scale: It's the scale applied to the lattice.scores. A smaller value yields more unique paths. oov: @@ -549,7 +550,7 @@ def nbest_oracle( lattice=lattice, num_paths=num_paths, use_double_scores=use_double_scores, - lattice_score_scale=lattice_score_scale, + nbest_scale=nbest_scale, ) hyps = nbest.build_levenshtein_graphs() @@ -590,7 +591,7 @@ def rescore_with_n_best_list( G: k2.Fsa, num_paths: int, lm_scale_list: List[float], - lattice_score_scale: float = 1.0, + nbest_scale: float = 1.0, use_double_scores: bool = True, ) -> Dict[str, k2.Fsa]: """Rescore an n-best list with an n-gram LM. @@ -607,7 +608,7 @@ def rescore_with_n_best_list( Size of nbest list. lm_scale_list: A list of float representing LM score scales. - lattice_score_scale: + nbest_scale: Scale to be applied to ``lattice.score`` when sampling paths using ``k2.random_paths``. use_double_scores: @@ -631,7 +632,7 @@ def rescore_with_n_best_list( lattice=lattice, num_paths=num_paths, use_double_scores=use_double_scores, - lattice_score_scale=lattice_score_scale, + nbest_scale=nbest_scale, ) # nbest.fsa.scores are all 0s at this point @@ -769,7 +770,7 @@ def rescore_with_attention_decoder( memory_key_padding_mask: Optional[torch.Tensor], sos_id: int, eos_id: int, - lattice_score_scale: float = 1.0, + nbest_scale: float = 1.0, ngram_lm_scale: Optional[float] = None, attention_scale: Optional[float] = None, use_double_scores: bool = True, @@ -796,7 +797,7 @@ def rescore_with_attention_decoder( The token ID for SOS. eos_id: The token ID for EOS. - lattice_score_scale: + nbest_scale: It's the scale applied to `lattice.scores`. 
A smaller value leads to more unique paths at the risk of missing the correct path. ngram_lm_scale: @@ -812,7 +813,7 @@ def rescore_with_attention_decoder( lattice=lattice, num_paths=num_paths, use_double_scores=use_double_scores, - lattice_score_scale=lattice_score_scale, + nbest_scale=nbest_scale, ) # nbest.fsa.scores are all 0s at this point diff --git a/test/test_decode.py b/test/test_decode.py index 7ef127781..97964ac67 100644 --- a/test/test_decode.py +++ b/test/test_decode.py @@ -43,7 +43,7 @@ def test_nbest_from_lattice(): lattice=lattice, num_paths=10, use_double_scores=True, - lattice_score_scale=0.5, + nbest_scale=0.5, ) # each lattice has only 4 distinct paths that have different word sequences: # 10->30 From adb068eb8242fe79dafce5a100c3fdfad934c7a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 1 Oct 2021 04:43:08 -0400 Subject: [PATCH 02/14] setup.py (#64) --- .gitignore | 1 + setup.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 setup.py diff --git a/.gitignore b/.gitignore index e6c84ca5e..f4f703243 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +icefall.egg-info/ data __pycache__ path.sh diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..6c720e121 --- /dev/null +++ b/setup.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 + +from setuptools import find_packages, setup +from pathlib import Path + +icefall_dir = Path(__file__).parent +install_requires = (icefall_dir / "requirements.txt").read_text().splitlines() + +setup( + name="icefall", + version="1.0", + python_requires=">=3.6.0", + description="Speech processing recipes using k2 and Lhotse.", + author="The k2 and Lhotse Development Team", + license="Apache-2.0 License", + packages=find_packages(), + install_requires=install_requires, + classifiers=[ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Intended Audience :: Science/Research", + "Operating System :: POSIX :: Linux", + "License :: OSI Approved :: Apache Software License", + "Topic :: Multimedia :: Sound/Audio :: Speech", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", + ], +) From b682467e4d5b3cc085d7fdb988f61e98979ae6a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 8 Oct 2021 22:32:13 -0400 Subject: [PATCH 03/14] Use BucketingSampler for dev and test data --- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 16 ++++++++-------- egs/yesno/ASR/tdnn/asr_datamodule.py | 17 ++++++++--------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 8290e71d1..4953e8538 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -21,6 +21,10 @@ from functools import lru_cache from pathlib import Path from typing import List, Union +from torch.utils.data import DataLoader + +from icefall.dataset.datamodule import DataModule +from icefall.utils import str2bool from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( BucketingSampler, @@ -32,10 +36,6 @@ from lhotse.dataset import ( SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from 
icefall.dataset.datamodule import DataModule -from icefall.utils import str2bool class LibriSpeechAsrDataModule(DataModule): @@ -267,7 +267,7 @@ class LibriSpeechAsrDataModule(DataModule): cut_transforms=transforms, return_cuts=self.args.return_cuts, ) - valid_sampler = SingleCutSampler( + valid_sampler = BucketingSampler( cuts_valid, max_duration=self.args.max_duration, shuffle=False, @@ -300,12 +300,12 @@ class LibriSpeechAsrDataModule(DataModule): else PrecomputedFeatures(), return_cuts=self.args.return_cuts, ) - sampler = SingleCutSampler( - cuts_test, max_duration=self.args.max_duration + sampler = BucketingSampler( + cuts_test, max_duration=self.args.max_duration, shuffle=False ) logging.debug("About to create test dataloader") test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=1 + test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers ) test_loaders.append(test_dl) diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py index e6614e3ce..a9a6145f0 100644 --- a/egs/yesno/ASR/tdnn/asr_datamodule.py +++ b/egs/yesno/ASR/tdnn/asr_datamodule.py @@ -20,19 +20,18 @@ from functools import lru_cache from pathlib import Path from typing import List +from torch.utils.data import DataLoader + +from icefall.dataset.datamodule import DataModule +from icefall.utils import str2bool from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( BucketingSampler, CutConcatenate, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.dataset.datamodule import DataModule -from icefall.utils import str2bool class YesNoAsrDataModule(DataModule): @@ -198,7 +197,7 @@ class YesNoAsrDataModule(DataModule): ) else: logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + train_sampler = BucketingSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, @@ -226,12 +225,12 @@ class YesNoAsrDataModule(DataModule): else PrecomputedFeatures(), return_cuts=self.args.return_cuts, ) - sampler = SingleCutSampler( - cuts_test, max_duration=self.args.max_duration + sampler = BucketingSampler( + cuts_test, max_duration=self.args.max_duration, shuffle=False ) logging.debug("About to create test dataloader") test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=1 + test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers ) return test_dl From 6e43905d124566dc0779e9c900b0a309c21850f0 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Sat, 9 Oct 2021 11:56:25 +0800 Subject: [PATCH 04/14] Update the documentation to include "ctc-decoding" (#71) * Update conformer_ctc.rst --- .../recipes/librispeech/conformer_ctc.rst | 68 ++++++++++++++++++- 1 file changed, 66 insertions(+), 2 deletions(-) diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst index a8b0683f4..73c5503d8 100644 --- a/docs/source/recipes/librispeech/conformer_ctc.rst +++ b/docs/source/recipes/librispeech/conformer_ctc.rst @@ -292,9 +292,18 @@ The commonly used options are: - ``--method`` - This specifies the decoding method. + This specifies the decoding method. This script supports 7 decoding methods. + As for ctc decoding, it uses a sentence piece model to convert word pieces to words. 
+ And it needs neither a lexicon nor an n-gram LM. + + For example, the following command uses CTC topology for decoding: + + .. code-block:: - The following command uses attention decoder for rescoring: + $ cd egs/librispeech/ASR + $ ./conformer_ctc/decode.py --method ctc-decoding --max-duration 300 + + And the following command uses attention decoder for rescoring: .. code-block:: @@ -311,6 +320,61 @@ The commonly used options are: It has the same meaning as the one during training. A larger value may cause OOM. +Here are some results for CTC decoding with a vocab size of 500: + +Usage: + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./conformer_ctc/decode.py \ + --epoch 25 \ + --avg 1 \ + --max-duration 300 \ + --exp-dir conformer_ctc/exp \ + --lang-dir data/lang_bpe_500 \ + --method ctc-decoding + +The output is given below: + +.. code-block:: bash + + 2021-09-26 12:44:31,033 INFO [decode.py:537] Decoding started + 2021-09-26 12:44:31,033 INFO [decode.py:538] + {'lm_dir': PosixPath('data/lm'), 'subsampling_factor': 4, 'vgg_frontend': False, 'use_feat_batchnorm': True, + 'feature_dim': 80, 'nhead': 8, 'attention_dim': 512, 'num_decoder_layers': 6, 'search_beam': 20, 'output_beam': 8, + 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, + 'epoch': 25, 'avg': 1, 'method': 'ctc-decoding', 'num_paths': 100, 'nbest_scale': 0.5, + 'export': False, 'exp_dir': PosixPath('conformer_ctc/exp'), 'lang_dir': PosixPath('data/lang_bpe_500'), 'full_libri': False, + 'feature_dir': PosixPath('data/fbank'), 'max_duration': 100, 'bucketing_sampler': False, 'num_buckets': 30, + 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, + 'shuffle': True, 'return_cuts': True, 'num_workers': 2} + 2021-09-26 12:44:31,406 INFO [lexicon.py:113] Loading pre-compiled data/lang_bpe_500/Linv.pt + 2021-09-26 12:44:31,464 INFO [decode.py:548] device: cuda:0 + 2021-09-26 12:44:36,171 INFO [checkpoint.py:92] Loading checkpoint from conformer_ctc/exp/epoch-25.pt + 2021-09-26 12:44:36,776 INFO [decode.py:652] Number of model parameters: 109226120 + 2021-09-26 12:44:37,714 INFO [decode.py:473] batch 0/206, cuts processed until now is 12 + 2021-09-26 12:45:15,944 INFO [decode.py:473] batch 100/206, cuts processed until now is 1328 + 2021-09-26 12:45:54,443 INFO [decode.py:473] batch 200/206, cuts processed until now is 2563 + 2021-09-26 12:45:56,411 INFO [decode.py:494] The transcripts are stored in conformer_ctc/exp/recogs-test-clean-ctc-decoding.txt + 2021-09-26 12:45:56,592 INFO [utils.py:331] [test-clean-ctc-decoding] %WER 3.26% [1715 / 52576, 163 ins, 128 del, 1424 sub ] + 2021-09-26 12:45:56,807 INFO [decode.py:506] Wrote detailed error stats to conformer_ctc/exp/errs-test-clean-ctc-decoding.txt + 2021-09-26 12:45:56,808 INFO [decode.py:522] + For test-clean, WER of different settings are: + ctc-decoding 3.26 best for test-clean + + 2021-09-26 12:45:57,362 INFO [decode.py:473] batch 0/203, cuts processed until now is 15 + 2021-09-26 12:46:35,565 INFO [decode.py:473] batch 100/203, cuts processed until now is 1477 + 2021-09-26 12:47:15,106 INFO [decode.py:473] batch 200/203, cuts processed until now is 2922 + 2021-09-26 12:47:16,131 INFO [decode.py:494] The transcripts are stored in conformer_ctc/exp/recogs-test-other-ctc-decoding.txt + 2021-09-26 12:47:16,208 INFO [utils.py:331] [test-other-ctc-decoding] %WER 8.21% [4295 / 52343, 396 ins, 315 del, 3584 sub ] + 2021-09-26 12:47:16,432 INFO [decode.py:506] Wrote detailed error stats to 
conformer_ctc/exp/errs-test-other-ctc-decoding.txt + 2021-09-26 12:47:16,432 INFO [decode.py:522] + For test-other, WER of different settings are: + ctc-decoding 8.21 best for test-other + + 2021-09-26 12:47:16,433 INFO [decode.py:680] Done! + Pre-trained Model ----------------- From 069ebaf9bab80209359b5ed19a3756cde2a551f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sat, 9 Oct 2021 14:45:46 +0000 Subject: [PATCH 05/14] Reformatting --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 5 ++++- egs/yesno/ASR/tdnn/asr_datamodule.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 4953e8538..229575db6 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -305,7 +305,10 @@ class LibriSpeechAsrDataModule(DataModule): ) logging.debug("About to create test dataloader") test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, ) test_loaders.append(test_dl) diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py index a9a6145f0..832fd556e 100644 --- a/egs/yesno/ASR/tdnn/asr_datamodule.py +++ b/egs/yesno/ASR/tdnn/asr_datamodule.py @@ -230,7 +230,10 @@ class YesNoAsrDataModule(DataModule): ) logging.debug("About to create test dataloader") test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, ) return test_dl From beb54ddb61897e0585e318096fdde16386378e2a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 12 Oct 2021 14:55:05 +0800 Subject: [PATCH 06/14] Support torch script. (#65) * WIP: Support torchscript. * Minor fixes. * Fix style issues. * Add documentation about how to deploy a trained model. --- README.md | 15 ++ .../recipes/librispeech/conformer_ctc.rst | 143 +++++++++++++-- egs/librispeech/ASR/conformer_ctc/export.py | 165 ++++++++++++++++++ .../ASR/conformer_ctc/transformer.py | 34 ++-- egs/librispeech/ASR/prepare.sh | 1 + 5 files changed, 330 insertions(+), 28 deletions(-) create mode 100755 egs/librispeech/ASR/conformer_ctc/export.py diff --git a/README.md b/README.md index dc03c5883..298feca2e 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,22 @@ The WER for this model is: We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing) + +## Deployment with C++ + +Once you have trained a model in icefall, you may want to deploy it with C++, +without Python dependencies. + +Please refer to the documentation + +for how to do this. + +We also provide a Colab notebook, showing you how to run a torch scripted model in [k2][k2] with C++. 
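
A condensed sketch of that workflow (the flags mirror `conformer_ctc/export.py`
and k2's `decode` binary as documented in this repo; paths are illustrative):

```bash
# Export a torch scripted model; writes conformer_ctc/exp/cpu_jit.pt.
./conformer_ctc/export.py --epoch 30 --avg 20 --jit 1 \
  --lang-dir data/lang_bpe --exp-dir conformer_ctc/exp

# Decode with the C++ binary built from k2 (branch v2.0-pre).
./build-release/bin/decode --use_ctc_decoding true \
  --jit_pt conformer_ctc/exp/cpu_jit.pt \
  --bpe_model data/lang_bpe/bpe.model \
  /path/to/foo.wav
```

The Colab notebook below walks through the same steps.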
+Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1BIGLWzS36isskMXHKcqC9ysN6pspYXs_?usp=sharing) + + [LibriSpeech_tdnn_lstm_ctc]: egs/librispeech/ASR/tdnn_lstm_ctc [LibriSpeech_conformer_ctc]: egs/librispeech/ASR/conformer_ctc [yesno]: egs/yesno/ASR [librispeech]: egs/librispeech/ASR +[k2]: https://github.com/k2-fsa/k2 diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst index 73c5503d8..84e99306f 100644 --- a/docs/source/recipes/librispeech/conformer_ctc.rst +++ b/docs/source/recipes/librispeech/conformer_ctc.rst @@ -20,6 +20,7 @@ In this tutorial, you will learn: - (2) How to start the training, either with a single GPU or multiple GPUs - (3) How to do decoding after training, with n-gram LM rescoring and attention decoder rescoring - (4) How to use a pre-trained model, provided by us + - (5) How to deploy your trained model in C++, without Python dependencies Data preparation ---------------- @@ -292,12 +293,12 @@ The commonly used options are: - ``--method`` - This specifies the decoding method. This script supports 7 decoding methods. - As for ctc decoding, it uses a sentence piece model to convert word pieces to words. + This specifies the decoding method. This script supports 7 decoding methods. + As for ctc decoding, it uses a sentence piece model to convert word pieces to words. And it needs neither a lexicon nor an n-gram LM. - + For example, the following command uses CTC topology for decoding: - + .. code-block:: $ cd egs/librispeech/ASR @@ -334,20 +335,20 @@ Usage: --exp-dir conformer_ctc/exp \ --lang-dir data/lang_bpe_500 \ --method ctc-decoding - + The output is given below: .. code-block:: bash 2021-09-26 12:44:31,033 INFO [decode.py:537] Decoding started - 2021-09-26 12:44:31,033 INFO [decode.py:538] - {'lm_dir': PosixPath('data/lm'), 'subsampling_factor': 4, 'vgg_frontend': False, 'use_feat_batchnorm': True, + 2021-09-26 12:44:31,033 INFO [decode.py:538] + {'lm_dir': PosixPath('data/lm'), 'subsampling_factor': 4, 'vgg_frontend': False, 'use_feat_batchnorm': True, 'feature_dim': 80, 'nhead': 8, 'attention_dim': 512, 'num_decoder_layers': 6, 'search_beam': 20, 'output_beam': 8, - 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, - 'epoch': 25, 'avg': 1, 'method': 'ctc-decoding', 'num_paths': 100, 'nbest_scale': 0.5, - 'export': False, 'exp_dir': PosixPath('conformer_ctc/exp'), 'lang_dir': PosixPath('data/lang_bpe_500'), 'full_libri': False, - 'feature_dir': PosixPath('data/fbank'), 'max_duration': 100, 'bucketing_sampler': False, 'num_buckets': 30, - 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, + 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, + 'epoch': 25, 'avg': 1, 'method': 'ctc-decoding', 'num_paths': 100, 'nbest_scale': 0.5, + 'export': False, 'exp_dir': PosixPath('conformer_ctc/exp'), 'lang_dir': PosixPath('data/lang_bpe_500'), 'full_libri': False, + 'feature_dir': PosixPath('data/fbank'), 'max_duration': 100, 'bucketing_sampler': False, 'num_buckets': 30, + 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': True, 'return_cuts': True, 'num_workers': 2} 2021-09-26 12:44:31,406 INFO [lexicon.py:113] Loading pre-compiled data/lang_bpe_500/Linv.pt 2021-09-26 12:44:31,464 INFO [decode.py:548] device: cuda:0 @@ -373,7 +374,7 @@ The output is given below: For test-other, WER of 
different settings are: ctc-decoding 8.21 best for test-other - 2021-09-26 12:47:16,433 INFO [decode.py:680] Done! + 2021-09-26 12:47:16,433 INFO [decode.py:680] Done! Pre-trained Model ----------------- @@ -693,3 +694,119 @@ We do provide a colab notebook for this recipe showing how to use a pre-trained **Congratulations!** You have finished the librispeech ASR recipe with conformer CTC models in ``icefall``. + +If you want to deploy your trained model in C++, please read the following section. + +Deployment with C++ +------------------- + +This section describes how to deploy your trained model in C++, without +Python dependencies. + +We assume you have run ``./prepare.sh`` and have the following directories available: + +.. code-block:: bash + + data + |-- lang_bpe + +Also, we assume your checkpoints are saved in ``conformer_ctc/exp``. + +If you know that averaging 20 checkpoints starting from ``epoch-30.pt`` yields the +lowest WER, you can run the following commands + +.. code-block:: + + $ cd egs/librispeech/ASR + $ ./conformer_ctc/export.py \ + --epoch 30 \ + --avg 20 \ + --jit 1 \ + --lang-dir data/lang_bpe \ + --exp-dir conformer_ctc/exp + +to get a torch scripted model saved in ``conformer_ctc/exp/cpu_jit.pt``. + +Now you have all needed files ready. Let us compile k2 from source: + +.. code-block:: bash + + $ cd $HOME + $ git clone https://github.com/k2-fsa/k2 + $ cd k2 + $ git checkout v2.0-pre + +.. CAUTION:: + + You have to switch to the branch ``v2.0-pre``! + +.. code-block:: bash + + $ mkdir build-release + $ cd build-release + $ cmake -DCMAKE_BUILD_TYPE=Release .. + $ make -j decode + # You will find an executable: `./bin/decode` + +Now you are ready to go! + +To view the usage of ``./bin/decode``, run: + +.. code-block:: + + $ ./bin/decode + +It will show you the following message: + +.. code-block:: + + Please provide --jit_pt + + (1) CTC decoding + ./bin/decode \ + --use_ctc_decoding true \ + --jit_pt \ + --bpe_model \ + /path/to/foo.wav \ + /path/to/bar.wav \ + + (2) HLG decoding + ./bin/decode \ + --use_ctc_decoding false \ + --jit_pt \ + --hlg \ + --word-table \ + /path/to/foo.wav \ + /path/to/bar.wav \ + + + --use_gpu false to use CPU + --use_gpu true to use GPU + +``./bin/decode`` supports two types of decoding at present: CTC decoding and HLG decoding. + +CTC decoding +^^^^^^^^^^^^ + +You need to provide: + + - ``--jit_pt``, this is the file generated by ``conformer_ctc/export.py``. You can find it + in ``conformer_ctc/exp/cpu_jit.pt``. + - ``--bpe_model``, this is a sentence piece model generated by ``prepare.sh``. You can find + it in ``data/lang_bpe/bpe.model``. + + +HLG decoding +^^^^^^^^^^^^ + +You need to provide: + + - ``--jit_pt``, this is the same file as in CTC decoding. + - ``--hlg``, this file is generated by ``prepare.sh``. You can find it in ``data/lang_bpe/HLG.pt``. + - ``--word-table``, this file is generated by ``prepare.sh``. You can find it in ``data/lang_bpe/words.txt``. + +We do provide a Colab notebook, showing you how to run a torch scripted model in C++. +Please see |librispeech asr conformer ctc torch script colab notebook| + +.. 
|librispeech asr conformer ctc torch script colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg + :target: https://colab.research.google.com/drive/1BIGLWzS36isskMXHKcqC9ysN6pspYXs_?usp=sharing diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py new file mode 100755 index 000000000..8241c84c1 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/export.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. + +import argparse +import logging +from pathlib import Path + +import torch +from conformer import Conformer + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=34, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=20, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe", + help="""It contains language related input files such as "lexicon.txt" + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=True, + help="""True to save a model after applying torch.jit.script. 
+ """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 8, + "num_decoder_layers": 6, + } + ) + return params + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + params.update(vars(args)) + + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index 68a4ff65c..a2e36a41e 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -236,6 +236,7 @@ class Transformer(nn.Module): x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) return x + @torch.jit.export def decoder_forward( self, memory: torch.Tensor, @@ -264,11 +265,15 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) @@ -301,6 +306,7 @@ class Transformer(nn.Module): return decoder_loss + @torch.jit.export def decoder_nll( self, memory: torch.Tensor, @@ -331,11 +337,15 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) + ys_in_pad = pad_sequence( + 
ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) @@ -649,7 +659,8 @@ class PositionalEncoding(nn.Module): self.d_model = d_model self.xscale = math.sqrt(self.d_model) self.dropout = nn.Dropout(p=dropout) - self.pe = None + # not doing: self.pe = None because of errors thrown by torchscript + self.pe = torch.zeros(0, self.d_model, dtype=torch.float32) def extend_pe(self, x: torch.Tensor) -> None: """Extend the time t in the positional encoding if required. @@ -666,8 +677,7 @@ class PositionalEncoding(nn.Module): """ if self.pe is not None: if self.pe.size(1) >= x.size(1): - if self.pe.dtype != x.dtype or self.pe.device != x.device: - self.pe = self.pe.to(dtype=x.dtype, device=x.device) + self.pe = self.pe.to(dtype=x.dtype, device=x.device) return pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) @@ -972,10 +982,7 @@ def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: Return a new list-of-list, where each sublist starts with SOS ID. """ - ans = [] - for utt in token_ids: - ans.append([sos_id] + utt) - return ans + return [[sos_id] + utt for utt in token_ids] def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: @@ -992,7 +999,4 @@ def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: Return a new list-of-list, where each sublist ends with EOS ID. """ - ans = [] - for utt in token_ids: - ans.append(utt + [eos_id]) - return ans + return [utt + [eos_id] for utt in token_ids] diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index f06e013f6..8aa972806 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -41,6 +41,7 @@ dl_dir=$PWD/download # data/lang_bpe_yyy if the array contains xxx, yyy vocab_sizes=( 5000 + 500 ) # All files generated by this script are saved in "data". From 597c5efdb11b8880ea1a5b62537042836cae11c1 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Tue, 12 Oct 2021 15:58:03 +0800 Subject: [PATCH 07/14] Use LossRecord to record and print the loss for the training process (#62) * Update index.rst (AS->ASR) * Update conformer_ctc.rst (pretraind->pretrained) * Fix some spelling errors. * Fix some spelling errors. * Use LossRecord to record and print loss in the training process * Change the name "LossRecord" to "MetricsTracker" --- egs/librispeech/ASR/conformer_ctc/train.py | 186 ++++++--------------- egs/librispeech/ASR/prepare.sh | 6 +- egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 112 +++++-------- egs/yesno/ASR/prepare.sh | 2 +- egs/yesno/ASR/tdnn/train.py | 102 +++++------ icefall/utils.py | 88 +++++++++- 6 files changed, 222 insertions(+), 274 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 80b2d924a..3e1049fbf 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang, -# Wei Kang) +# Wei Kang +# Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -21,13 +22,15 @@ import argparse import logging from pathlib import Path from shutil import copyfile -from typing import Optional +from typing import Optional, Tuple + import k2 import torch -import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn +from torch import Tensor + from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer from lhotse.utils import fix_random_seed @@ -43,6 +46,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, + MetricsTracker, encode_supervisions, setup_logger, str2bool, @@ -287,7 +291,7 @@ def compute_loss( batch: dict, graph_compiler: BpeCtcTrainingGraphCompiler, is_training: bool, -): +) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. @@ -367,15 +371,17 @@ def compute_loss( loss = ctc_loss att_loss = torch.tensor([0]) - # train_frames and valid_frames are used for printing. - if is_training: - params.train_frames = supervision_segments[:, 2].sum().item() - else: - params.valid_frames = supervision_segments[:, 2].sum().item() - assert loss.requires_grad == is_training - return loss, ctc_loss.detach(), att_loss.detach() + info = MetricsTracker() + info["frames"] = supervision_segments[:, 2].sum().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.att_rate != 0.0: + info["att_loss"] = att_loss.detach().cpu().item() + + info["loss"] = loss.detach().cpu().item() + + return loss, info def compute_validation_loss( @@ -384,18 +390,14 @@ def compute_validation_loss( graph_compiler: BpeCtcTrainingGraphCompiler, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, -) -> None: - """Run the validation process. The validation loss - is saved in `params.valid_loss`. 
- """ +) -> MetricsTracker: + """Run the validation process.""" model.eval() - tot_loss = 0.0 - tot_ctc_loss = 0.0 - tot_att_loss = 0.0 - tot_frames = 0.0 + tot_loss = MetricsTracker() + for batch_idx, batch in enumerate(valid_dl): - loss, ctc_loss, att_loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, @@ -403,36 +405,17 @@ def compute_validation_loss( is_training=False, ) assert loss.requires_grad is False - assert ctc_loss.requires_grad is False - assert att_loss.requires_grad is False - - loss_cpu = loss.detach().cpu().item() - tot_loss += loss_cpu - - tot_ctc_loss += ctc_loss.detach().cpu().item() - tot_att_loss += att_loss.detach().cpu().item() - - tot_frames += params.valid_frames + tot_loss = tot_loss + loss_info if world_size > 1: - s = torch.tensor( - [tot_loss, tot_ctc_loss, tot_att_loss, tot_frames], - device=loss.device, - ) - dist.all_reduce(s, op=dist.ReduceOp.SUM) - s = s.cpu().tolist() - tot_loss = s[0] - tot_ctc_loss = s[1] - tot_att_loss = s[2] - tot_frames = s[3] + tot_loss.reduce(loss.device) - params.valid_loss = tot_loss / tot_frames - params.valid_ctc_loss = tot_ctc_loss / tot_frames - params.valid_att_loss = tot_att_loss / tot_frames - - if params.valid_loss < params.best_valid_loss: + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = params.valid_loss + params.best_valid_loss = loss_value + + return tot_loss def train_one_epoch( @@ -471,24 +454,21 @@ def train_one_epoch( """ model.train() - tot_loss = 0.0 # sum of losses over all batches - tot_ctc_loss = 0.0 - tot_att_loss = 0.0 + tot_loss = MetricsTracker() - tot_frames = 0.0 # sum of frames over all batches - params.tot_loss = 0.0 - params.tot_frames = 0.0 for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - loss, ctc_loss, att_loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, graph_compiler=graph_compiler, is_training=True, ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. 
@@ -498,75 +478,26 @@ def train_one_epoch( clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() - loss_cpu = loss.detach().cpu().item() - ctc_loss_cpu = ctc_loss.detach().cpu().item() - att_loss_cpu = att_loss.detach().cpu().item() - - tot_frames += params.train_frames - tot_loss += loss_cpu - tot_ctc_loss += ctc_loss_cpu - tot_att_loss += att_loss_cpu - - params.tot_frames += params.train_frames - params.tot_loss += loss_cpu - - tot_avg_loss = tot_loss / tot_frames - tot_avg_ctc_loss = tot_ctc_loss / tot_frames - tot_avg_att_loss = tot_att_loss / tot_frames - if batch_idx % params.log_interval == 0: logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, " - f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, " - f"batch avg loss {loss_cpu/params.train_frames:.4f}, " - f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, " - f"total avg att loss: {tot_avg_att_loss:.4f}, " - f"total avg loss: {tot_avg_loss:.4f}, " - f"batch size: {batch_size}" + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" ) + if batch_idx % 10 == 0: + if tb_writer is not None: - tb_writer.add_scalar( - "train/current_ctc_loss", - ctc_loss_cpu / params.train_frames, - params.batch_idx_train, + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train ) - tb_writer.add_scalar( - "train/current_att_loss", - att_loss_cpu / params.train_frames, - params.batch_idx_train, + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train ) - tb_writer.add_scalar( - "train/current_loss", - loss_cpu / params.train_frames, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/tot_avg_ctc_loss", - tot_avg_ctc_loss, - params.batch_idx_train, - ) - - tb_writer.add_scalar( - "train/tot_avg_att_loss", - tot_avg_att_loss, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/tot_avg_loss", - tot_avg_loss, - params.batch_idx_train, - ) - if batch_idx > 0 and batch_idx % params.reset_interval == 0: - tot_loss = 0.0 # sum of losses over all batches - tot_ctc_loss = 0.0 - tot_att_loss = 0.0 - - tot_frames = 0.0 # sum of frames over all batches if batch_idx > 0 and batch_idx % params.valid_interval == 0: - compute_validation_loss( + logging.info("Computing validation loss") + valid_info = compute_validation_loss( params=params, model=model, graph_compiler=graph_compiler, @@ -574,33 +505,14 @@ def train_one_epoch( world_size=world_size, ) model.train() - logging.info( - f"Epoch {params.cur_epoch}, " - f"valid ctc loss {params.valid_ctc_loss:.4f}," - f"valid att loss {params.valid_att_loss:.4f}," - f"valid loss {params.valid_loss:.4f}," - f" best valid loss: {params.best_valid_loss:.4f} " - f"best valid epoch: {params.best_valid_epoch}" - ) + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") if tb_writer is not None: - tb_writer.add_scalar( - "train/valid_ctc_loss", - params.valid_ctc_loss, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/valid_att_loss", - params.valid_att_loss, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/valid_loss", - params.valid_loss, - params.batch_idx_train, + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train ) - params.train_loss = params.tot_loss / params.tot_frames - + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value if params.train_loss < params.best_train_loss: params.best_train_epoch = 
params.cur_epoch params.best_train_loss = params.train_loss diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 8aa972806..b536cb472 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -57,13 +57,13 @@ log() { log "dl_dir: $dl_dir" if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "stage -1: Download LM" + log "Stage -1: Download LM" [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm ./local/download_lm.py --out-dir=$dl_dir/lm fi if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "stage 0: Download data" + log "Stage 0: Download data" # If you have pre-downloaded it to /path/to/LibriSpeech, # you can create a symlink @@ -126,7 +126,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "State 6: Prepare BPE based lang" + log "Stage 6: Prepare BPE based lang" for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 695ee5130..51a486e07 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -20,14 +21,15 @@ import argparse import logging from pathlib import Path from shutil import copyfile -from typing import Optional +from typing import Optional, Tuple import k2 import torch -import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn import torch.optim as optim +from torch import Tensor + from asr_datamodule import LibriSpeechAsrDataModule from lhotse.utils import fix_random_seed from model import TdnnLstm @@ -43,6 +45,7 @@ from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, + MetricsTracker, encode_supervisions, setup_logger, str2bool, @@ -267,7 +270,7 @@ def compute_loss( batch: dict, graph_compiler: CtcTrainingGraphCompiler, is_training: bool, -): +) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. @@ -324,13 +327,11 @@ def compute_loss( assert loss.requires_grad == is_training - # train_frames and valid_frames are used for printing. - if is_training: - params.train_frames = supervision_segments[:, 2].sum().item() - else: - params.valid_frames = supervision_segments[:, 2].sum().item() + info = MetricsTracker() + info["frames"] = supervision_segments[:, 2].sum().item() + info["loss"] = loss.detach().cpu().item() - return loss + return loss, info def compute_validation_loss( @@ -339,16 +340,16 @@ def compute_validation_loss( graph_compiler: CtcTrainingGraphCompiler, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, -) -> None: +) -> MetricsTracker: """Run the validation process. The validation loss is saved in `params.valid_loss`. 
""" model.eval() - tot_loss = 0.0 - tot_frames = 0.0 + tot_loss = MetricsTracker() + for batch_idx, batch in enumerate(valid_dl): - loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, @@ -357,22 +358,18 @@ def compute_validation_loss( ) assert loss.requires_grad is False - loss_cpu = loss.detach().cpu().item() - tot_loss += loss_cpu - tot_frames += params.valid_frames + tot_loss = tot_loss + loss_info if world_size > 1: - s = torch.tensor([tot_loss, tot_frames], device=loss.device) - dist.all_reduce(s, op=dist.ReduceOp.SUM) - s = s.cpu().tolist() - tot_loss = s[0] - tot_frames = s[1] + tot_loss.reduce(loss.device) - params.valid_loss = tot_loss / tot_frames + loss_value = tot_loss["loss"] / tot_loss["frames"] - if params.valid_loss < params.best_valid_loss: + if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = params.valid_loss + params.best_valid_loss = loss_value + + return tot_loss def train_one_epoch( @@ -411,67 +408,45 @@ def train_one_epoch( """ model.train() - tot_loss = 0.0 # reset after params.reset_interval of batches - tot_frames = 0.0 # reset after params.reset_interval of batches - - params.tot_loss = 0.0 - params.tot_frames = 0.0 + tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, graph_compiler=graph_compiler, is_training=True, ) - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. + # summary stats. + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info optimizer.zero_grad() loss.backward() clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() - loss_cpu = loss.detach().cpu().item() - - tot_frames += params.train_frames - tot_loss += loss_cpu - tot_avg_loss = tot_loss / tot_frames - - params.tot_frames += params.train_frames - params.tot_loss += loss_cpu - if batch_idx % params.log_interval == 0: logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"batch avg loss {loss_cpu/params.train_frames:.4f}, " - f"total avg loss: {tot_avg_loss:.4f}, " - f"batch size: {batch_size}" + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" ) + if batch_idx % 10 == 0: + if tb_writer is not None: - tb_writer.add_scalar( - "train/current_loss", - loss_cpu / params.train_frames, - params.batch_idx_train, + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train ) - - tb_writer.add_scalar( - "train/tot_avg_loss", - tot_avg_loss, - params.batch_idx_train, + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train ) - if batch_idx > 0 and batch_idx % params.reset_interval == 0: - tot_loss = 0 - tot_frames = 0 - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - compute_validation_loss( + valid_info = compute_validation_loss( params=params, model=model, graph_compiler=graph_compiler, @@ -479,13 +454,16 @@ def train_one_epoch( world_size=world_size, ) model.train() - logging.info( - f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f}," - f" best valid loss: {params.best_valid_loss:.4f} " - f"best valid epoch: {params.best_valid_epoch}" - ) + logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}") + if tb_writer is not None: + 
valid_info.write_summary( + tb_writer, + "train/valid_", + params.batch_idx_train, + ) - params.train_loss = params.tot_loss / params.tot_frames + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value if params.train_loss < params.best_train_loss: params.best_train_epoch = params.cur_epoch diff --git a/egs/yesno/ASR/prepare.sh b/egs/yesno/ASR/prepare.sh index 9a0cc48bb..8fcee0290 100755 --- a/egs/yesno/ASR/prepare.sh +++ b/egs/yesno/ASR/prepare.sh @@ -24,7 +24,7 @@ log() { log "dl_dir: $dl_dir" if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "stage 0: Download data" + log "Stage 0: Download data" mkdir -p $dl_dir if [ ! -f $dl_dir/waves_yesno/.completed ]; then diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py index 0f5506d38..6cc511a28 100755 --- a/egs/yesno/ASR/tdnn/train.py +++ b/egs/yesno/ASR/tdnn/train.py @@ -4,14 +4,14 @@ import argparse import logging from pathlib import Path from shutil import copyfile -from typing import Optional +from typing import Optional, Tuple import k2 import torch -import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn import torch.optim as optim +from torch import Tensor from asr_datamodule import YesNoAsrDataModule from lhotse.utils import fix_random_seed from model import Tdnn @@ -24,7 +24,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, setup_logger, str2bool +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool def get_parser(): @@ -122,6 +122,8 @@ def get_params() -> AttributeDict: - valid_interval: Run validation if batch_idx % valid_interval` is 0 + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + - beam_size: It is used in k2.ctc_loss - reduction: It is used in k2.ctc_loss @@ -142,6 +144,7 @@ def get_params() -> AttributeDict: "best_valid_epoch": -1, "batch_idx_train": 0, "log_interval": 10, + "reset_interval": 20, "valid_interval": 10, "beam_size": 10, "reduction": "sum", @@ -245,7 +248,7 @@ def compute_loss( batch: dict, graph_compiler: CtcTrainingGraphCompiler, is_training: bool, -): +) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. @@ -305,13 +308,11 @@ def compute_loss( assert loss.requires_grad == is_training - # train_frames and valid_frames are used for printing. - if is_training: - params.train_frames = supervision_segments[:, 2].sum().item() - else: - params.valid_frames = supervision_segments[:, 2].sum().item() + info = MetricsTracker() + info["frames"] = supervision_segments[:, 2].sum().item() + info["loss"] = loss.detach().cpu().item() - return loss + return loss, info def compute_validation_loss( @@ -320,16 +321,16 @@ def compute_validation_loss( graph_compiler: CtcTrainingGraphCompiler, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, -) -> None: +) -> MetricsTracker: """Run the validation process. The validation loss is saved in `params.valid_loss`. 
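
The ``tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info`` update used in ``train_one_epoch`` in these recipes replaces the old periodic hard reset with a geometric decay. A small numeric sketch of its behaviour (``reset_interval = 200`` is just an example value; each recipe sets its own):

.. code-block:: python

    # After reset_interval further batches, an old batch still carries
    # roughly exp(-1) ~ 37% of its weight instead of being zeroed out.
    reset_interval = 200  # example value
    decay = 1 - 1 / reset_interval

    weight = 1.0
    for _ in range(reset_interval):
        weight *= decay

    print(f"{weight:.3f}")  # 0.367
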
""" model.eval() - tot_loss = 0.0 - tot_frames = 0.0 + tot_loss = MetricsTracker() + for batch_idx, batch in enumerate(valid_dl): - loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, @@ -338,22 +339,18 @@ def compute_validation_loss( ) assert loss.requires_grad is False - loss_cpu = loss.detach().cpu().item() - tot_loss += loss_cpu - tot_frames += params.valid_frames + tot_loss = tot_loss + loss_info if world_size > 1: - s = torch.tensor([tot_loss, tot_frames], device=loss.device) - dist.all_reduce(s, op=dist.ReduceOp.SUM) - s = s.cpu().tolist() - tot_loss = s[0] - tot_frames = s[1] + tot_loss.reduce(loss.device) - params.valid_loss = tot_loss / tot_frames + loss_value = tot_loss["loss"] / tot_loss["frames"] - if params.valid_loss < params.best_valid_loss: + if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = params.valid_loss + params.best_valid_loss = loss_value + + return tot_loss def train_one_epoch( @@ -392,57 +389,45 @@ def train_one_epoch( """ model.train() - tot_loss = 0.0 # sum of losses over all batches - tot_frames = 0.0 # sum of frames over all batches + tot_loss = MetricsTracker() + for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, graph_compiler=graph_compiler, is_training=True, ) - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. + # summary stats. + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info optimizer.zero_grad() loss.backward() clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() - loss_cpu = loss.detach().cpu().item() - - tot_frames += params.train_frames - tot_loss += loss_cpu - tot_avg_loss = tot_loss / tot_frames - if batch_idx % params.log_interval == 0: logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"batch avg loss {loss_cpu/params.train_frames:.4f}, " - f"total avg loss: {tot_avg_loss:.4f}, " - f"batch size: {batch_size}" + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" ) + if batch_idx % 10 == 0: if tb_writer is not None: - tb_writer.add_scalar( - "train/current_loss", - loss_cpu / params.train_frames, - params.batch_idx_train, + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train ) - - tb_writer.add_scalar( - "train/tot_avg_loss", - tot_avg_loss, - params.batch_idx_train, + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: - compute_validation_loss( + valid_info = compute_validation_loss( params=params, model=model, graph_compiler=graph_compiler, @@ -450,19 +435,16 @@ def train_one_epoch( world_size=world_size, ) model.train() - logging.info( - f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f}," - f" best valid loss: {params.best_valid_loss:.4f} " - f"best valid epoch: {params.best_valid_epoch}" - ) + logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}") if tb_writer is not None: - tb_writer.add_scalar( - "train/valid_loss", - params.valid_loss, + valid_info.write_summary( + tb_writer, + "train/valid_", params.batch_idx_train, ) - params.train_loss = tot_loss / tot_frames + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = 
loss_value if params.train_loss < params.best_train_loss: params.best_train_epoch = params.cur_epoch diff --git a/icefall/utils.py b/icefall/utils.py index 23b4dd6c7..66aa5c601 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1,4 +1,5 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Mingshuang Luo) # # See ../../LICENSE for clarification regarding multiple authors # @@ -17,6 +18,7 @@ import argparse import logging +import collections import os import subprocess from collections import defaultdict @@ -29,6 +31,7 @@ import k2 import kaldialign import torch import torch.distributed as dist +from torch.utils.tensorboard import SummaryWriter Pathlike = Union[str, Path] @@ -166,8 +169,8 @@ def encode_supervisions( supervisions: dict, subsampling_factor: int ) -> Tuple[torch.Tensor, List[str]]: """ - Encodes Lhotse's ``batch["supervisions"]`` dict into a pair of torch Tensor, - and a list of transcription strings. + Encodes Lhotse's ``batch["supervisions"]`` dict into + a pair of torch Tensor, and a list of transcription strings. The supervision tensor has shape ``(batch_size, 3)``. Its second dimension contains information about sequence index [0], @@ -272,13 +275,13 @@ def write_error_stats( Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606 reference words (2337 correct) - - The difference between the reference transcript and predicted results. + - The difference between the reference transcript and predicted result. An instance is given below:: THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES - The above example shows that the reference word is `EDISON`, but it is - predicted to `ADDISON` (a substitution error). + The above example shows that the reference word is `EDISON`, + but it is predicted to `ADDISON` (a substitution error). Another example is:: @@ -419,3 +422,76 @@ def write_error_stats( print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) return float(tot_err_rate) + + +class MetricsTracker(collections.defaultdict): + def __init__(self): + # Passing the type 'int' to the base-class constructor + # makes undefined items default to int() which is zero. + # This class will play a role as metrics tracker. + # It can record many metrics, including but not limited to loss. + super(MetricsTracker, self).__init__(int) + + def __add__(self, other: "MetricsTracker") -> "MetricsTracker": + ans = MetricsTracker() + for k, v in self.items(): + ans[k] = v + for k, v in other.items(): + ans[k] = ans[k] + v + return ans + + def __mul__(self, alpha: float) -> "MetricsTracker": + ans = MetricsTracker() + for k, v in self.items(): + ans[k] = v * alpha + return ans + + def __str__(self) -> str: + ans = "" + for k, v in self.norm_items(): + norm_value = "%.4g" % v + ans += str(k) + "=" + str(norm_value) + ", " + frames = str(self["frames"]) + ans += "over " + frames + " frames." + return ans + + def norm_items(self) -> List[Tuple[str, float]]: + """ + Returns a list of pairs, like: + [('ctc_loss', 0.1), ('att_loss', 0.07)] + """ + num_frames = self["frames"] if "frames" in self else 1 + ans = [] + for k, v in self.items(): + if k != "frames": + norm_value = float(v) / num_frames + ans.append((k, norm_value)) + return ans + + def reduce(self, device): + """ + Reduce using torch.distributed, which I believe ensures that + all processes get the total. 
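
A hypothetical usage sketch of the ``MetricsTracker`` operations defined above (it assumes ``icefall`` is importable; the numbers are made up):

.. code-block:: python

    from icefall.utils import MetricsTracker

    a = MetricsTracker()
    a["frames"] = 100
    a["loss"] = 25.0

    b = MetricsTracker()
    b["frames"] = 50
    b["loss"] = 20.0

    tot = a + b              # per-key sums: 150 frames, 45.0 loss
    print(tot.norm_items())  # [('loss', 0.3)] -- loss normalized by frames
    print(tot)               # loss=0.3, over 150 frames.
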
+ """ + keys = sorted(self.keys()) + s = torch.tensor([float(self[k]) for k in keys], device=device) + dist.all_reduce(s, op=dist.ReduceOp.SUM) + for k, v in zip(keys, s.cpu().tolist()): + self[k] = v + + def write_summary( + self, + tb_writer: SummaryWriter, + prefix: str, + batch_idx: int, + ) -> None: + """Add logging information to a TensorBoard writer. + + Args: + tb_writer: a TensorBoard writer + prefix: a prefix for the name of the loss, e.g. "train/valid_", + or "train/current_" + batch_idx: The current batch index, used as the x-axis of the plot. + """ + for k, v in self.norm_items(): + tb_writer.add_scalar(prefix + k, v, batch_idx) From 391432b35644a9cdf93b55351b1950d666eea256 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Tue, 12 Oct 2021 21:30:31 +0800 Subject: [PATCH 08/14] Update train.py ("10"--->"params.log_interval") (#76) * Update train.py * Update train.py * Update train.py --- egs/librispeech/ASR/conformer_ctc/train.py | 2 +- egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 2 +- egs/yesno/ASR/tdnn/train.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 3e1049fbf..5554aaa7c 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -485,7 +485,7 @@ def train_one_epoch( f"tot_loss[{tot_loss}], batch size: {batch_size}" ) - if batch_idx % 10 == 0: + if batch_idx % params.log_interval == 0: if tb_writer is not None: loss_info.write_summary( diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 51a486e07..4a8574019 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -435,7 +435,7 @@ def train_one_epoch( f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}" ) - if batch_idx % 10 == 0: + if batch_idx % params.log_interval == 0: if tb_writer is not None: loss_info.write_summary( diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py index 6cc511a28..d414962ca 100755 --- a/egs/yesno/ASR/tdnn/train.py +++ b/egs/yesno/ASR/tdnn/train.py @@ -416,7 +416,7 @@ def train_one_epoch( f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}" ) - if batch_idx % 10 == 0: + if batch_idx % params.log_interval == 0: if tb_writer is not None: loss_info.write_summary( From 39bc8cae94cb3b5824a93b5033136fba546322b9 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Wed, 13 Oct 2021 12:20:16 +0800 Subject: [PATCH 09/14] Add ctc decoding to pretrained.py on conformer_ctc (#75) * Add ctc-decoding to pretrained.py * update pretrained.py and conformer_ctc.rst * update ctc-decoding for pretrained.py on conformer_ctc * Update pretrained.py * fix the style issue * Update conformer_ctc.rst * Update the running logs --- .../recipes/librispeech/conformer_ctc.rst | 119 +++++++---- .../ASR/conformer_ctc/pretrained.py | 202 +++++++++++------- 2 files changed, 211 insertions(+), 110 deletions(-) diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst index 84e99306f..45ad79313 100644 --- a/docs/source/recipes/librispeech/conformer_ctc.rst +++ b/docs/source/recipes/librispeech/conformer_ctc.rst @@ -429,6 +429,7 @@ After downloading, you will have the following files: |-- README.md |-- data | |-- 
lang_bpe + | | |-- Linv.pt | | |-- HLG.pt | | |-- bpe.model | | |-- tokens.txt @@ -446,6 +447,9 @@ After downloading, you will have the following files: 6 directories, 11 files **File descriptions**: + - ``data/lang_bpe/Linv.pt`` + + It is the lexicon file, with word IDs as labels and token IDs as aux_labels. - ``data/lang_bpe/HLG.pt`` @@ -527,12 +531,58 @@ Usage displays the help information. -It supports three decoding methods: +It supports 4 decoding methods: + - CTC decoding - HLG decoding - HLG + n-gram LM rescoring - HLG + n-gram LM rescoring + attention decoder rescoring +CTC decoding +^^^^^^^^^^^^ + +CTC decoding uses the best path of the decoding lattice as the decoding result +without any LM or lexicon. + +The command to run CTC decoding is: + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./conformer_ctc/pretrained.py \ + --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \ + --lang-dir ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe \ + --method ctc-decoding \ + ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \ + ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \ + ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac + +The output is given below: + +.. code-block:: + + 2021-10-13 11:21:50,896 INFO [pretrained.py:236] device: cuda:0 + 2021-10-13 11:21:50,896 INFO [pretrained.py:238] Creating model + 2021-10-13 11:21:56,669 INFO [pretrained.py:255] Constructing Fbank computer + 2021-10-13 11:21:56,670 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'] + 2021-10-13 11:21:56,683 INFO [pretrained.py:271] Decoding started + 2021-10-13 11:21:57,341 INFO [pretrained.py:290] Building CTC topology + 2021-10-13 11:21:57,625 INFO [lexicon.py:113] Loading pre-compiled tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/Linv.pt + 2021-10-13 11:21:57,679 INFO [pretrained.py:299] Loading BPE model + 2021-10-13 11:22:00,076 INFO [pretrained.py:314] Use CTC decoding + 2021-10-13 11:22:00,087 INFO [pretrained.py:400] + ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac: + AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS + + ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac: + GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED + BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN + + ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac: + YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION + + 2021-10-13 11:22:00,087 INFO [pretrained.py:402] Decoding Done + HLG decoding ^^^^^^^^^^^^ @@ -545,8 +595,7 @@ The command to run HLG decoding is: $ cd egs/librispeech/ASR $ ./conformer_ctc/pretrained.py \ --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \ - --words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \ - --HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \ + --lang-dir ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe \ 
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac @@ -555,14 +604,14 @@ The output is given below: .. code-block:: - 2021-08-20 11:03:05,712 INFO [pretrained.py:217] device: cuda:0 - 2021-08-20 11:03:05,712 INFO [pretrained.py:219] Creating model - 2021-08-20 11:03:11,345 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt - 2021-08-20 11:03:18,442 INFO [pretrained.py:255] Constructing Fbank computer - 2021-08-20 11:03:18,444 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'] - 2021-08-20 11:03:18,507 INFO [pretrained.py:271] Decoding started - 2021-08-20 11:03:18,795 INFO [pretrained.py:300] Use HLG decoding - 2021-08-20 11:03:19,149 INFO [pretrained.py:339] + 2021-10-13 11:25:19,458 INFO [pretrained.py:236] device: cuda:0 + 2021-10-13 11:25:19,458 INFO [pretrained.py:238] Creating model + 2021-10-13 11:25:25,342 INFO [pretrained.py:255] Constructing Fbank computer + 2021-10-13 11:25:25,343 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'] + 2021-10-13 11:25:25,356 INFO [pretrained.py:271] Decoding started + 2021-10-13 11:25:26,026 INFO [pretrained.py:327] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt + 2021-10-13 11:25:33,735 INFO [pretrained.py:359] Use HLG decoding + 2021-10-13 11:25:34,013 INFO [pretrained.py:400] ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac: AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS @@ -573,7 +622,7 @@ The output is given below: ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac: YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION - 2021-08-20 11:03:19,149 INFO [pretrained.py:341] Decoding Done + 2021-10-13 11:25:34,014 INFO [pretrained.py:402] Decoding Done HLG decoding + LM rescoring ^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -588,8 +637,7 @@ The command to run HLG decoding + LM rescoring is: $ cd egs/librispeech/ASR $ ./conformer_ctc/pretrained.py \ --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \ - --words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \ - --HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \ + --lang-dir ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe \ --method whole-lattice-rescoring \ --G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \ --ngram-lm-scale 0.8 \ @@ -601,15 +649,15 @@ Its output is: .. 
code-block:: - 2021-08-20 11:12:17,565 INFO [pretrained.py:217] device: cuda:0 - 2021-08-20 11:12:17,565 INFO [pretrained.py:219] Creating model - 2021-08-20 11:12:23,728 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt - 2021-08-20 11:12:30,035 INFO [pretrained.py:246] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt - 2021-08-20 11:13:10,779 INFO [pretrained.py:255] Constructing Fbank computer - 2021-08-20 11:13:10,787 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'] - 2021-08-20 11:13:10,798 INFO [pretrained.py:271] Decoding started - 2021-08-20 11:13:11,085 INFO [pretrained.py:305] Use HLG decoding + LM rescoring - 2021-08-20 11:13:11,736 INFO [pretrained.py:339] + 2021-10-13 11:28:19,129 INFO [pretrained.py:236] device: cuda:0 + 2021-10-13 11:28:19,129 INFO [pretrained.py:238] Creating model + 2021-10-13 11:28:23,531 INFO [pretrained.py:255] Constructing Fbank computer + 2021-10-13 11:28:23,532 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'] + 2021-10-13 11:28:23,544 INFO [pretrained.py:271] Decoding started + 2021-10-13 11:28:24,141 INFO [pretrained.py:327] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt + 2021-10-13 11:28:30,752 INFO [pretrained.py:338] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt + 2021-10-13 11:28:48,308 INFO [pretrained.py:364] Use HLG decoding + LM rescoring + 2021-10-13 11:28:48,815 INFO [pretrained.py:400] ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac: AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS @@ -620,7 +668,7 @@ Its output is: ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac: YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION - 2021-08-20 11:13:11,737 INFO [pretrained.py:341] Decoding Done + 2021-10-13 11:28:48,815 INFO [pretrained.py:402] Decoding Done HLG decoding + LM rescoring + attention decoder rescoring ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -636,8 +684,7 @@ The command to run HLG decoding + LM rescoring + attention decoder rescoring is: $ cd egs/librispeech/ASR $ ./conformer_ctc/pretrained.py \ --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \ - --words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \ - --HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \ + --lang-dir ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe \ --method attention-decoder \ --G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \ --ngram-lm-scale 1.3 \ @@ -654,15 +701,15 @@ The output is below: .. 
code-block:: - 2021-08-20 11:19:11,397 INFO [pretrained.py:217] device: cuda:0 - 2021-08-20 11:19:11,397 INFO [pretrained.py:219] Creating model - 2021-08-20 11:19:17,354 INFO [pretrained.py:238] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt - 2021-08-20 11:19:24,615 INFO [pretrained.py:246] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt - 2021-08-20 11:20:04,576 INFO [pretrained.py:255] Constructing Fbank computer - 2021-08-20 11:20:04,584 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'] - 2021-08-20 11:20:04,595 INFO [pretrained.py:271] Decoding started - 2021-08-20 11:20:04,854 INFO [pretrained.py:313] Use HLG + LM rescoring + attention decoder rescoring - 2021-08-20 11:20:05,805 INFO [pretrained.py:339] + 2021-10-13 11:29:50,106 INFO [pretrained.py:236] device: cuda:0 + 2021-10-13 11:29:50,106 INFO [pretrained.py:238] Creating model + 2021-10-13 11:29:56,063 INFO [pretrained.py:255] Constructing Fbank computer + 2021-10-13 11:29:56,063 INFO [pretrained.py:265] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'] + 2021-10-13 11:29:56,077 INFO [pretrained.py:271] Decoding started + 2021-10-13 11:29:56,770 INFO [pretrained.py:327] Loading HLG from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt + 2021-10-13 11:30:04,023 INFO [pretrained.py:338] Loading G from ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt + 2021-10-13 11:30:18,163 INFO [pretrained.py:372] Use HLG + LM rescoring + attention decoder rescoring + 2021-10-13 11:30:19,367 INFO [pretrained.py:400] ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac: AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS @@ -673,7 +720,7 @@ The output is below: ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac: YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION - 2021-08-20 11:20:05,805 INFO [pretrained.py:341] Decoding Done + 2021-10-13 11:30:19,367 INFO [pretrained.py:402] Decoding Done Colab notebook -------------- diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 00812d674..07d3e7269 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -19,6 +20,7 @@ import argparse import logging import math +import sentencepiece as spm from typing import List import k2 @@ -28,6 +30,7 @@ import torchaudio from conformer import Conformer from torch.nn.utils.rnn import pad_sequence +from icefall.lexicon import Lexicon from icefall.decode import ( get_lattice, one_best_decoding, @@ -52,14 +55,10 @@ def get_parser(): ) parser.add_argument( - "--words-file", + "--lang-dir", type=str, required=True, - help="Path to words.txt", - ) - - parser.add_argument( - "--HLG", type=str, required=True, help="Path to HLG.pt." + help="Path to lang bpe dir.", ) parser.add_argument( @@ -68,6 +67,10 @@ def get_parser(): default="1best", help="""Decoding method. Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. (1) 1best - Use the best path as decoding output. Only the transformer encoder output is used for decoding. We call it HLG decoding. @@ -249,23 +252,6 @@ def main(): model.to(device) model.eval() - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - if not hasattr(HLG, "lm_scores"): - # For whole-lattice-rescoring and attention-decoder - HLG.lm_scores = HLG.scores.clone() - - if params.method in ["whole-lattice-rescoring", "attention-decoder"]: - logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = G.to(device) - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G.lm_scores = G.scores.clone() - logging.info("Constructing Fbank computer") opts = kaldifeat.FbankOptions() opts.device = device @@ -299,60 +285,128 @@ def main(): dtype=torch.int32, ) - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) + try: + if params.method == "ctc-decoding": + logging.info("Building CTC topology") + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) - if params.method == "1best": - logging.info("Use HLG decoding") - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - elif params.method == "whole-lattice-rescoring": - logging.info("Use HLG decoding + LM rescoring") - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=[params.ngram_lm_scale], - ) - best_path = next(iter(best_path_dict.values())) - elif params.method == "attention-decoder": - logging.info("Use HLG + LM rescoring + attention decoder rescoring") - rescored_lattice = rescore_with_whole_lattice( - lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None - ) - best_path_dict = rescore_with_attention_decoder( - lattice=rescored_lattice, - num_paths=params.num_paths, - model=model, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - sos_id=params.sos_id, - eos_id=params.eos_id, - nbest_scale=params.nbest_scale, 
- ngram_lm_scale=params.ngram_lm_scale, - attention_scale=params.attention_decoder_scale, - ) - best_path = next(iter(best_path_dict.values())) + logging.info("Loading BPE model") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.lang_dir + "/bpe.model") - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) + logging.info("Use CTC decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] - logging.info("Decoding Done") + if params.method in [ + "1best", + "whole-lattice-rescoring", + "attention-decoder", + ]: + logging.info(f"Loading HLG from {params.lang_dir}/HLG.pt") + HLG = k2.Fsa.from_dict( + torch.load(params.lang_dir + "/HLG.pt", map_location="cpu") + ) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "whole-lattice-rescoring", + "attention-decoder", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = G.to(device) + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "attention-decoder": + logging.info( + "Use HLG + LM rescoring + attention decoder rescoring" + ) + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None + ) + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=params.sos_id, + eos_id=params.eos_id, + nbest_scale=params.nbest_scale, + ngram_lm_scale=params.ngram_lm_scale, + attention_scale=params.attention_decoder_scale, + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file( + params.lang_dir + "/words.txt" + ) + hyps = [[word_sym_table[i] for i in ids] 
for ids in hyps] + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + except Exception: + raise ValueError("Please use a supported decoding method.") if __name__ == "__main__": From 5016ee3c95551842dc04333f12f5ca5791256ec1 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 14 Oct 2021 16:20:35 +0800 Subject: [PATCH 10/14] Give an informative message when users provide an unsupported decoding method (#77) --- .../ASR/conformer_ctc/pretrained.py | 219 +++++++++--------- 1 file changed, 106 insertions(+), 113 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 07d3e7269..be94e6875 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -20,23 +20,23 @@ import argparse import logging import math -import sentencepiece as spm from typing import List import k2 import kaldifeat +import sentencepiece as spm import torch import torchaudio from conformer import Conformer from torch.nn.utils.rnn import pad_sequence -from icefall.lexicon import Lexicon from icefall.decode import ( get_lattice, one_best_decoding, rescore_with_attention_decoder, rescore_with_whole_lattice, ) +from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, get_texts @@ -58,7 +58,7 @@ def get_parser(): "--lang-dir", type=str, required=True, - help="Path to lang bpe dir.", + help="Path to lang dir.", ) parser.add_argument( @@ -142,7 +142,7 @@ def get_parser(): parser.add_argument( "--sos-id", - type=float, + type=int, default=1, help=""" Used only when method is attention-decoder. @@ -152,7 +152,7 @@ def get_parser(): parser.add_argument( "--eos-id", - type=float, + type=int, default=1, help=""" Used only when method is attention-decoder. 
@@ -285,128 +285,121 @@ def main(): dtype=torch.int32, ) - try: - if params.method == "ctc-decoding": - logging.info("Building CTC topology") - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) - logging.info("Loading BPE model") - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(params.lang_dir + "/bpe.model") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.lang_dir + "/bpe.model") - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=H, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) - logging.info("Use CTC decoding") - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - token_ids = get_texts(best_path) - hyps = bpe_model.decode(token_ids) - hyps = [s.split() for s in hyps] + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "whole-lattice-rescoring", + "attention-decoder", + ]: + logging.info(f"Loading HLG from {params.lang_dir}/HLG.pt") + HLG = k2.Fsa.from_dict( + torch.load(params.lang_dir + "/HLG.pt", map_location="cpu") + ) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() if params.method in [ - "1best", "whole-lattice-rescoring", "attention-decoder", ]: - logging.info(f"Loading HLG from {params.lang_dir}/HLG.pt") - HLG = k2.Fsa.from_dict( - torch.load(params.lang_dir + "/HLG.pt", map_location="cpu") + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = G.to(device) + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores ) - HLG = HLG.to(device) - if not hasattr(HLG, "lm_scores"): - # For whole-lattice-rescoring and attention-decoder - HLG.lm_scores = HLG.scores.clone() - - if params.method in [ - "whole-lattice-rescoring", - "attention-decoder", - ]: - logging.info(f"Loading G from 
{params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = G.to(device) - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G.lm_scores = G.scores.clone() - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], ) - - if params.method == "1best": - logging.info("Use HLG decoding") - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - elif params.method == "whole-lattice-rescoring": - logging.info("Use HLG decoding + LM rescoring") - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=[params.ngram_lm_scale], - ) - best_path = next(iter(best_path_dict.values())) - elif params.method == "attention-decoder": - logging.info( - "Use HLG + LM rescoring + attention decoder rescoring" - ) - rescored_lattice = rescore_with_whole_lattice( - lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None - ) - best_path_dict = rescore_with_attention_decoder( - lattice=rescored_lattice, - num_paths=params.num_paths, - model=model, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - sos_id=params.sos_id, - eos_id=params.eos_id, - nbest_scale=params.nbest_scale, - ngram_lm_scale=params.ngram_lm_scale, - attention_scale=params.attention_decoder_scale, - ) - best_path = next(iter(best_path_dict.values())) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file( - params.lang_dir + "/words.txt" + best_path = next(iter(best_path_dict.values())) + elif params.method == "attention-decoder": + logging.info("Use HLG + LM rescoring + attention decoder rescoring") + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None ) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=params.sos_id, + eos_id=params.eos_id, + nbest_scale=params.nbest_scale, + ngram_lm_scale=params.ngram_lm_scale, + attention_scale=params.attention_decoder_scale, + ) + best_path = next(iter(best_path_dict.values())) - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file( + params.lang_dir + "/words.txt" + ) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") - logging.info("Decoding Done") + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) - except Exception: - raise ValueError("Please use a supported decoding method.") + logging.info("Decoding Done") if __name__ == "__main__": From 
f2387fe523c6f89987f3723bfa967095e8de5127 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 14 Oct 2021 20:09:38 +0800 Subject: [PATCH 11/14] Fix a bug introduced while supporting torch script. (#79) --- egs/librispeech/ASR/conformer_ctc/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index a2e36a41e..3e6abb695 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -660,7 +660,7 @@ class PositionalEncoding(nn.Module): self.xscale = math.sqrt(self.d_model) self.dropout = nn.Dropout(p=dropout) # not doing: self.pe = None because of errors thrown by torchscript - self.pe = torch.zeros(0, self.d_model, dtype=torch.float32) + self.pe = torch.zeros(0, 0, dtype=torch.float32) def extend_pe(self, x: torch.Tensor) -> None: """Extend the time t in the positional encoding if required. From 5401ce199d271003e79e1cf13597851a45fc9b3e Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Thu, 14 Oct 2021 23:29:06 +0800 Subject: [PATCH 12/14] Update ctc-decoding on pretrained.py and conformer_ctc.rst (#78) --- .../recipes/librispeech/conformer_ctc.rst | 16 +++---- .../ASR/conformer_ctc/pretrained.py | 43 ++++++++++++------- 2 files changed, 34 insertions(+), 25 deletions(-) diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst index 45ad79313..2a956750f 100644 --- a/docs/source/recipes/librispeech/conformer_ctc.rst +++ b/docs/source/recipes/librispeech/conformer_ctc.rst @@ -429,7 +429,6 @@ After downloading, you will have the following files: |-- README.md |-- data | |-- lang_bpe - | | |-- Linv.pt | | |-- HLG.pt | | |-- bpe.model | | |-- tokens.txt @@ -447,10 +446,6 @@ After downloading, you will have the following files: 6 directories, 11 files **File descriptions**: - - ``data/lang_bpe/Linv.pt`` - - It is the lexicon file, with word IDs as labels and token IDs as aux_labels. - - ``data/lang_bpe/HLG.pt`` It is the decoding graph. 
@@ -551,7 +546,7 @@ The command to run CTC decoding is: $ cd egs/librispeech/ASR $ ./conformer_ctc/pretrained.py \ --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \ - --lang-dir ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe \ + --bpe-model ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/bpe.model \ --method ctc-decoding \ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \ @@ -595,7 +590,8 @@ The command to run HLG decoding is: $ cd egs/librispeech/ASR $ ./conformer_ctc/pretrained.py \ --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \ - --lang-dir ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe \ + --words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \ + --HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac @@ -637,7 +633,8 @@ The command to run HLG decoding + LM rescoring is: $ cd egs/librispeech/ASR $ ./conformer_ctc/pretrained.py \ --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \ - --lang-dir ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe \ + --words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \ + --HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \ --method whole-lattice-rescoring \ --G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \ --ngram-lm-scale 0.8 \ @@ -684,7 +681,8 @@ The command to run HLG decoding + LM rescoring + attention decoder rescoring is: $ cd egs/librispeech/ASR $ ./conformer_ctc/pretrained.py \ --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \ - --lang-dir ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe \ + --words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \ + --HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \ --method attention-decoder \ --G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \ --ngram-lm-scale 1.3 \ diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index be94e6875..edbdb5b2e 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -36,7 +36,6 @@ from icefall.decode import ( rescore_with_attention_decoder, rescore_with_whole_lattice, ) -from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, get_texts @@ -55,10 +54,27 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--words-file", type=str, - required=True, - help="Path to lang dir.", + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. 
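
With ``--bpe-model``, the CTC topology for ``ctc-decoding`` can be built from the BPE vocabulary alone, without loading a ``Lexicon``. A hedged sketch of that path (the model path here is an example):

.. code-block:: python

    import k2
    import sentencepiece as spm

    bpe_model = spm.SentencePieceProcessor()
    bpe_model.load("data/lang_bpe/bpe.model")  # example path

    # Token IDs run from 0 to vocab_size - 1, so the largest ID is:
    max_token_id = bpe_model.get_piece_size() - 1
    H = k2.ctc_topo(max_token=max_token_id, modified=False)
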
+ """, ) parser.add_argument( @@ -287,17 +303,16 @@ def main(): if params.method == "ctc-decoding": logging.info("Use CTC decoding") - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = bpe_model.get_piece_size() - 1 + H = k2.ctc_topo( max_token=max_token_id, modified=False, device=device, ) - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(params.lang_dir + "/bpe.model") - lattice = get_lattice( nnet_output=nnet_output, decoding_graph=H, @@ -320,10 +335,8 @@ def main(): "whole-lattice-rescoring", "attention-decoder", ]: - logging.info(f"Loading HLG from {params.lang_dir}/HLG.pt") - HLG = k2.Fsa.from_dict( - torch.load(params.lang_dir + "/HLG.pt", map_location="cpu") - ) + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -386,9 +399,7 @@ def main(): best_path = next(iter(best_path_dict.values())) hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file( - params.lang_dir + "/words.txt" - ) + word_sym_table = k2.SymbolTable.from_file(params.words_file) hyps = [[word_sym_table[i] for i in ids] for ids in hyps] else: raise ValueError(f"Unsupported decoding method: {params.method}") From fee1f84b20a5a704428c5eac80de2ac4033e1b27 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 15 Oct 2021 00:41:33 +0800 Subject: [PATCH 13/14] Test pre-trained model in CI (#80) * Add CI to run pre-trained models. * Minor fixes. * Install kaldifeat * Install a CPU version of PyTorch. * Fix CI errors. * Disable decoder layers in pretrained.py if it is not used. * Clone pre-trained model from GitHub. * Minor fixes. * Minor fixes. * Minor fixes. --- .github/workflows/run-pretrained.yml | 106 ++++++++++++++++++ .github/workflows/test.yml | 5 + .../ASR/conformer_ctc/pretrained.py | 20 +++- egs/librispeech/ASR/conformer_ctc/train.py | 30 +++-- 4 files changed, 150 insertions(+), 11 deletions(-) create mode 100644 .github/workflows/run-pretrained.yml diff --git a/.github/workflows/run-pretrained.yml b/.github/workflows/run-pretrained.yml new file mode 100644 index 000000000..97d3c32d2 --- /dev/null +++ b/.github/workflows/run-pretrained.yml @@ -0,0 +1,106 @@ +# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: run-pre-trained-conformer-ctc + +on: + push: + branches: + - master + pull_request: + types: [labeled] + +jobs: + run_pre_trained_conformer_ctc: + if: github.event.label.name == 'ready' || github.event_name == 'push' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-18.04] + python-version: [3.6, 3.7, 3.8, 3.9] + torch: ["1.8.1"] + k2-version: ["1.9.dev20210919"] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Python dependencies + run: | + python3 -m pip install --upgrade pip pytest + pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/ + + python3 -m pip install git+https://github.com/lhotse-speech/lhotse + python3 -m pip install kaldifeat + # We are in ./icefall and there is a file: requirements.txt in it + pip install -r requirements.txt + + - name: Install graphviz + shell: bash + run: | + python3 -m pip install -qq graphviz + sudo apt-get -qq install graphviz + + - name: Download pre-trained model + shell: bash + run: | + sudo apt-get -qq install git-lfs tree sox + cd egs/librispeech/ASR + mkdir tmp + cd tmp + git lfs install + git clone https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500 + cd .. + tree tmp + soxi tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/*.flac + ls -lh tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/*.flac + + - name: Run CTC decoding + shell: bash + run: | + export PYTHONPATH=$PWD:PYTHONPATH + cd egs/librispeech/ASR + ./conformer_ctc/pretrained.py \ + --num-classes 500 \ + --checkpoint ./tmp/icefall-asr-conformer-ctc-bpe-500/exp/pretrained.pt \ + --bpe-model ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1089-134686-0001.flac \ + ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0001.flac \ + ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0002.flac + + - name: Run HLG decoding + shell: bash + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + cd egs/librispeech/ASR + ./conformer_ctc/pretrained.py \ + --num-classes 500 \ + --checkpoint ./tmp/icefall-asr-conformer-ctc-bpe-500/exp/pretrained.pt \ + --words-file ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/words.txt \ + --HLG ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/HLG.pt \ + ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1089-134686-0001.flac \ + ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0001.flac \ + ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0002.flac diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 150b5258a..c6114ce73 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -84,3 +84,8 @@ jobs: echo "lib_path: $lib_path" export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH pytest ./test + + # runt tests for conformer ctc + cd egs/librispeech/ASR/conformer_ctc + pytest + diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index edbdb5b2e..99bd9c017 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -166,6 +166,15 @@ def get_parser(): """, ) + parser.add_argument( + 
"--num-classes", + type=int, + default=5000, + help=""" + Vocab size in the BPE model. + """, + ) + parser.add_argument( "--eos-id", type=int, @@ -199,7 +208,6 @@ def get_params() -> AttributeDict: "use_feat_batchnorm": True, "feature_dim": 80, "nhead": 8, - "num_classes": 5000, "attention_dim": 512, "num_decoder_layers": 6, # parameters for decoding @@ -242,7 +250,13 @@ def main(): args = parser.parse_args() params = get_params() + if args.method != "attention-decoder": + # to save memory as the attention decoder + # will not be used + params.num_decoder_layers = 0 + params.update(vars(args)) + logging.info(f"{params}") device = torch.device("cpu") @@ -264,7 +278,7 @@ def main(): ) checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"]) + model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -305,7 +319,7 @@ def main(): logging.info("Use CTC decoding") bpe_model = spm.SentencePieceProcessor() bpe_model.load(params.bpe_model) - max_token_id = bpe_model.get_piece_size() - 1 + max_token_id = params.num_classes - 1 H = k2.ctc_topo( max_token=max_token_id, diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 5554aaa7c..d1cdfa8bb 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -96,6 +96,26 @@ def get_parser(): """, ) + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + return parser @@ -110,12 +130,6 @@ def get_params() -> AttributeDict: Explanation of options saved in `params`: - - exp_dir: It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - - - lang_dir: It contains language related input files such as - "lexicon.txt" - - best_train_loss: Best training loss so far. It is used to select the model that has the lowest training loss. It is updated during the training. @@ -166,8 +180,6 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { - "exp_dir": Path("conformer_ctc/exp"), - "lang_dir": Path("data/lang_bpe"), "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, @@ -638,6 +650,8 @@ def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) world_size = args.world_size assert world_size >= 1 From bd7c2f7645c0aea4cb482cc1f60836907a61d36b Mon Sep 17 00:00:00 2001 From: "Jan \"yenda\" Trmal" Date: Fri, 15 Oct 2021 19:46:17 -0400 Subject: [PATCH 14/14] fix conformer typo in docs (#83) --- docs/source/recipes/librispeech/conformer_ctc.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst index 2a956750f..57ac246e1 100644 --- a/docs/source/recipes/librispeech/conformer_ctc.rst +++ b/docs/source/recipes/librispeech/conformer_ctc.rst @@ -1,4 +1,4 @@ -Confromer CTC +Conformer CTC ============= This tutorial shows you how to run a conformer ctc model