From 25fa6c0690a2649285956fbd2c031209cd3c247b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 25 Oct 2021 11:24:26 +0800 Subject: [PATCH] Support switching modified/standard CTC topo from commandline. --- egs/librispeech/ASR/conformer_ctc/decode.py | 18 +++++-- .../ASR/conformer_ctc/pretrained.py | 14 ++++- egs/librispeech/ASR/conformer_ctc/train.py | 11 +++- egs/librispeech/ASR/local/compile_hlg.py | 54 ++++++++++++++----- egs/librispeech/ASR/prepare.sh | 2 + 5 files changed, 80 insertions(+), 19 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 3b1d34757..7c6c1726d 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -147,6 +147,13 @@ def get_parser(): help="The lang dir", ) + parser.add_argument( + "--modified-ctc-topo", + type=str2bool, + default=False, + help="True to use modified ctc topo.", + ) + return parser @@ -563,7 +570,7 @@ def main(): HLG = None H = k2.ctc_topo( max_token=max_token_id, - modified=False, + modified=params.modified_ctc_topo, device=device, ) bpe_model = spm.SentencePieceProcessor() @@ -571,9 +578,12 @@ def main(): else: H = None bpe_model = None - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") - ) + if params.modified_ctc_topo: + filename = params.lang_dir / "HLG_modified.pt" + else: + filename = params.lang_dir / "HLG.pt" + logging.info(f"Loading {filename}") + HLG = k2.Fsa.from_dict(torch.load(filename, map_location="cpu")) HLG = HLG.to(device) assert HLG.requires_grad is False diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index beed6f73b..eb5896c1a 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -36,7 +36,7 @@ from icefall.decode import ( rescore_with_attention_decoder, rescore_with_whole_lattice, ) -from icefall.utils import 
AttributeDict, get_env_info, get_texts +from icefall.utils import AttributeDict, get_env_info, get_texts, str2bool def get_parser(): @@ -195,6 +195,15 @@ def get_parser(): "The sample rate has to be 16kHz.", ) + parser.add_argument( + "--modified-ctc-topo", + type=str2bool, + default=False, + help="""True to use modified ctc topo. + Used only when method is ctc-decoding. + """, + ) + return parser @@ -321,9 +330,10 @@ def main(): bpe_model.load(params.bpe_model) max_token_id = params.num_classes - 1 + logging.info(f"modified_ctc_topo: {params.modified_ctc_topo}") H = k2.ctc_topo( max_token=max_token_id, - modified=False, + modified=params.modified_ctc_topo, device=device, ) diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 223c8d993..2cc5a00a0 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -115,6 +115,13 @@ def get_parser(): """, ) + parser.add_argument( + "--modified-ctc-topo", + type=str2bool, + default=False, + help="True to use modified ctc topo.", + ) + return parser @@ -344,7 +351,9 @@ def compute_loss( token_ids = graph_compiler.texts_to_ids(texts) - decoding_graph = graph_compiler.compile(token_ids) + decoding_graph = graph_compiler.compile( + token_ids, modified=params.modified_ctc_topo + ) dense_fsa_vec = k2.DenseFsaVec( nnet_output, diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index 098d5d6a3..dc21e5e5b 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -27,6 +27,9 @@ This script takes as input lang_dir and generates HLG from - G, the LM, built from data/lm/G_3_gram.fst.txt The generated HLG is saved in $lang_dir/HLG.pt + +If the commandline argument --modified-ctc-topo is True, the generated +file is HLG_modified.pt """ import argparse import logging @@ -36,6 +39,7 @@ import k2 import torch from icefall.lexicon import Lexicon +from 
icefall.utils import str2bool def get_args(): @@ -46,15 +50,23 @@ def get_args(): help="""Input and output directory. """, ) + parser.add_argument( + "--modified-ctc-topo", + type=str2bool, + default=False, + help="True to use modified CTC topo", + ) return parser.parse_args() -def compile_HLG(lang_dir: str) -> k2.Fsa: +def compile_HLG(lang_dir: str, modified_ctc_topo: bool = False) -> k2.Fsa: """ Args: lang_dir: The language directory, e.g., data/lang_phone or data/lang_bpe_5000. + modified_ctc_topo: + True to use modified CTC topo. Return: An FSA representing HLG. @@ -62,17 +74,24 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: lexicon = Lexicon(lang_dir) max_token_id = max(lexicon.tokens) logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") - H = k2.ctc_topo(max_token_id) + if modified_ctc_topo: + logging.info("Using modified CTC topo") + else: + logging.info("Using standard CTC topo") + + H = k2.ctc_topo(max_token_id, modified=modified_ctc_topo) + logging.info(f"H.shape: {H.shape}, num_arcs: {H.num_arcs}") L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) if Path("data/lm/G_3_gram.pt").is_file(): - logging.info("Loading pre-compiled G_3_gram") + logging.info("Loading pre-compiled G_3_gram from data/lm/G_3_gram.pt") d = torch.load("data/lm/G_3_gram.pt") G = k2.Fsa.from_dict(d) else: logging.info("Loading G_3_gram.fst.txt") with open("data/lm/G_3_gram.fst.txt") as f: G = k2.Fsa.from_openfst(f.read(), acceptor=False) + logging.info("Saving pre-compiled data/lm/G_3_gram.pt") torch.save(G.as_dict(), "data/lm/G_3_gram.pt") first_token_disambig_id = lexicon.token_table["#0"] @@ -83,11 +102,13 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: logging.info("Intersecting L and G") LG = k2.compose(L, G) - logging.info(f"LG shape: {LG.shape}") + logging.info(f"LG shape: {LG.shape}, num_arcs: {LG.num_arcs}") logging.info("Connecting LG") LG = k2.connect(LG) - logging.info(f"LG shape after k2.connect: {LG.shape}") + logging.info( + f"LG shape after 
k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}" + ) logging.info(type(LG.aux_labels)) logging.info("Determinizing LG") @@ -106,7 +127,9 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 LG = k2.remove_epsilon(LG) - logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") + logging.info( + f"LG shape after k2.remove_epsilon: {LG.shape}, num_arcs: {LG.num_arcs}" + ) LG = k2.connect(LG) LG.aux_labels = LG.aux_labels.remove_values_eq(0) @@ -126,7 +149,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: logging.info("Arc sorting LG") HLG = k2.arc_sort(HLG) - logging.info(f"HLG.shape: {HLG.shape}") + logging.info(f"HLG.shape: {HLG.shape}, num_arcs: {HLG.num_arcs}") return HLG @@ -134,16 +157,23 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: def main(): args = get_args() lang_dir = Path(args.lang_dir) + logging.info(f"lang_dir: {args.lang_dir}") + logging.info(f"modified_ctc_topo: {args.modified_ctc_topo}") - if (lang_dir / "HLG.pt").is_file(): - logging.info(f"{lang_dir}/HLG.pt already exists - skipping") + if args.modified_ctc_topo: + filename = lang_dir / "HLG_modified.pt" + else: + filename = lang_dir / "HLG.pt" + + if filename.is_file(): + logging.info(f"{filename} already exists - skipping") return logging.info(f"Processing {lang_dir}") - HLG = compile_HLG(lang_dir) - logging.info(f"Saving HLG.pt to {lang_dir}") - torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") + HLG = compile_HLG(lang_dir, modified_ctc_topo=args.modified_ctc_topo) + logging.info(f"Saving {filename}") + torch.save(HLG.as_dict(), str(filename)) if __name__ == "__main__": diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 3b2678ec4..ee2f2192f 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -221,9 +221,11 @@ fi if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then log "Stage 9: Compile HLG" ./local/compile_hlg.py --lang-dir data/lang_phone + ./local/compile_hlg.py 
--lang-dir data/lang_phone --modified-ctc-topo True for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} ./local/compile_hlg.py --lang-dir $lang_dir + ./local/compile_hlg.py --lang-dir $lang_dir --modified-ctc-topo true done fi