From 334cd7d03d98e6e66cda5a72371036f4d94f81ed Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 9 Oct 2023 21:58:39 +0800 Subject: [PATCH] Support modified CTC topology --- egs/librispeech/ASR/local/prepare_lang_fst.py | 39 +++++++++-- .../ASR/zipformer/onnx_pretrained_ctc_H.py | 1 + .../ASR/zipformer/onnx_pretrained_ctc_HL.py | 1 + .../ASR/zipformer/onnx_pretrained_ctc_HLG.py | 1 + icefall/ctc/__init__.py | 7 +- icefall/ctc/test_ctc_topo.py | 30 +++++++- icefall/ctc/topo.py | 68 +++++++++++++++++-- 7 files changed, 135 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/local/prepare_lang_fst.py b/egs/librispeech/ASR/local/prepare_lang_fst.py index fb1e7f9c0..15cac24e5 100755 --- a/egs/librispeech/ASR/local/prepare_lang_fst.py +++ b/egs/librispeech/ASR/local/prepare_lang_fst.py @@ -8,6 +8,9 @@ tokens.txt, and words.txt and generates the following files: - H.fst - HL.fst + +If you also provide --ngram-G, then it also generates + - HLG.fst Note that saved files are in OpenFst binary format. @@ -22,6 +25,12 @@ Or ./local/prepare_lang_fst.py \ --lang-dir ./data/lang_bpe_500 + +or + +./local/prepare_lang_fst.py \ + --lang-dir ./data/lang_bpe_500 \ + --ngram-G ./data/lm/G_3_gram.fst.txt """ import argparse @@ -35,6 +44,7 @@ from icefall.ctc import ( add_disambig_self_loops, add_one, build_standard_ctc_topo, + build_ctc_topo_max_repeat0, make_lexicon_fst_no_silence, make_lexicon_fst_with_silence, ) @@ -65,6 +75,17 @@ def get_args(): """, ) + parser.add_argument( + "--max-num-repeats", + type=int, + default=-1, + help="""Allowed maximum number of repeats. + -1 means the standard CTC topology allowing infinite number of repeats. + 0 means it does not allow repeats at all. + Any other value is currently not supported. 
+ """, + ) + return parser.parse_args() @@ -171,14 +192,22 @@ def main(): lexicon = Lexicon(lang_dir) - logging.info("Building standard CTC topology") max_token_id = max(lexicon.tokens) - H = build_standard_ctc_topo(max_token_id=max_token_id) + if args.max_num_repeats == -1: + logging.info("Building standard CTC topology") + H = build_standard_ctc_topo(max_token_id=max_token_id) + suffix='' + elif args.max_num_repeats == 0: + logging.info("Building CTC topology allowing 0 repeat") + H = build_ctc_topo_max_repeat0(max_token_id=max_token_id) + suffix='_0' + else: + raise ValueError(f'Unsupported --max-num-repeats: {args.max_num_repeats}. Only -1 and 0 are valid') # We need to add one to all tokens since we want to use ID 0 # for epsilon add_one(H, treat_ilabel_zero_specially=False, update_olabel=True) - H.write(f"{lang_dir}/H.fst") + H.write(f"{lang_dir}/H{suffix}.fst") logging.info("Building L") # Now for HL @@ -195,7 +224,7 @@ def main(): has_silence=args.has_silence, lexicon=lexicon, ) - HL.write(f"{lang_dir}/HL.fst") + HL.write(f"{lang_dir}/HL{suffix}.fst") if not args.ngram_G: logging.info("Skip building HLG") @@ -209,7 +238,7 @@ def main(): ) HLG = build_HLG(H=H, L=L, G=G, has_silence=args.has_silence, lexicon=lexicon) - HLG.write(f"{lang_dir}/HLG.fst") + HLG.write(f"{lang_dir}/HLG{suffix}.fst") if __name__ == "__main__": diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py index 683a7dc20..dd8d46705 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py @@ -92,6 +92,7 @@ class OnnxModel: session_opts = ort.SessionOptions() session_opts.inter_op_num_threads = 1 session_opts.intra_op_num_threads = 1 + session_opts.log_severity_level = 3 # error level self.session_opts = session_opts diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py index 
0b94bfa65..cf90b388c 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py @@ -92,6 +92,7 @@ class OnnxModel: session_opts = ort.SessionOptions() session_opts.inter_op_num_threads = 1 session_opts.intra_op_num_threads = 1 + session_opts.log_severity_level = 3 # error level self.session_opts = session_opts diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py index 93569142a..2b781eac8 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py @@ -92,6 +92,7 @@ class OnnxModel: session_opts = ort.SessionOptions() session_opts.inter_op_num_threads = 1 session_opts.intra_op_num_threads = 1 + session_opts.log_severity_level = 3 # error level self.session_opts = session_opts diff --git a/icefall/ctc/__init__.py b/icefall/ctc/__init__.py index b546b31af..31b7565cb 100644 --- a/icefall/ctc/__init__.py +++ b/icefall/ctc/__init__.py @@ -3,4 +3,9 @@ from .prepare_lang import ( make_lexicon_fst_no_silence, make_lexicon_fst_with_silence, ) -from .topo import add_disambig_self_loops, add_one, build_standard_ctc_topo +from .topo import ( + add_disambig_self_loops, + add_one, + build_ctc_topo_max_repeat0, + build_standard_ctc_topo, +) diff --git a/icefall/ctc/test_ctc_topo.py b/icefall/ctc/test_ctc_topo.py index 4d4667209..e87cd9a34 100755 --- a/icefall/ctc/test_ctc_topo.py +++ b/icefall/ctc/test_ctc_topo.py @@ -11,7 +11,12 @@ from prepare_lang import ( make_lexicon_fst_no_silence, make_lexicon_fst_with_silence, ) -from topo import add_disambig_self_loops, add_one, build_standard_ctc_topo +from topo import ( + add_disambig_self_loops, + add_one, + build_standard_ctc_topo, + build_ctc_topo_max_repeat0, +) def test_yesno(): @@ -131,7 +136,30 @@ def test_librispeech(): print(sp.encode(["HELLOA", "WORLD"])) +def test_build_ctc_topo_max_repeat0(): + H = 
build_ctc_topo_max_repeat0(max_token_id=3) + isym = kaldifst.SymbolTable() + isym.add_symbol(symbol="", key=0) + isym.add_symbol(symbol="C", key=1) + isym.add_symbol(symbol="A", key=2) + isym.add_symbol(symbol="T", key=3) + + osym = kaldifst.SymbolTable() + osym.add_symbol(symbol="", key=0) + osym.add_symbol(symbol="C", key=1) + osym.add_symbol(symbol="A", key=2) + osym.add_symbol(symbol="T", key=3) + + H.input_symbols = isym + H.output_symbols = osym + + fst_dot = kaldifst.draw(H, acceptor=False, portrait=True) + source = graphviz.Source(fst_dot) + source.render(outfile="ctc_topo_max_repeat0.pdf") + + def main(): + test_build_ctc_topo_max_repeat0() test_yesno() test_librispeech() diff --git a/icefall/ctc/topo.py b/icefall/ctc/topo.py index 6a96dd038..b9cf56a4c 100644 --- a/icefall/ctc/topo.py +++ b/icefall/ctc/topo.py @@ -3,15 +3,16 @@ import kaldifst -# Note the name contains `standard`; it means there will be non-standard -# topologies. def build_standard_ctc_topo(max_token_id: int) -> kaldifst.StdVectorFst: """Build a standard CTC topology. Args: - Maximum valid token ID. We assume token IDs are contiguous - and starts from 0. In other words, the vocabulary size is - ``max_token_id + 1``. We assume the ID of the blank symbol is 0. + max_token_id: + Maximum valid token ID. We assume token IDs are contiguous + and start from 0. In other words, the vocabulary size is + ``max_token_id + 1``. We assume the ID of the blank symbol is 0. + Returns: + Return a transducer representing the standard CTC topology. """ # Token ID starts from 0 and there are as many states as the # number of tokens. @@ -54,6 +55,63 @@ def build_standard_ctc_topo(max_token_id: int) -> kaldifst.StdVectorFst: return fst +def build_ctc_topo_max_repeat0(max_token_id: int) -> kaldifst.StdVectorFst: + """Build a modified CTC topology which does not allow any repeats. + + We remove the self-loop of each state except the state for the blank symbol. + + Args: + max_token_id: + Maximum valid token ID. 
We assume token IDs are contiguous + and start from 0. In other words, the vocabulary size is + ``max_token_id + 1``. We assume the ID of the blank symbol is 0. + Returns: + Return a transducer representing the modified CTC topology. + """ + # Token ID starts from 0 and there are as many states as the + # number of tokens. + # + # Note that epsilon is not a token and the token with ID 0 in tokens.txt + # is not an epsilon. It means input label 0 of the resulting FST does + # not represent an epsilon. + # + # You can use the function `add_one()` to modify the input/output labels + # of the resulting FST + + num_states = max_token_id + 1 + + # Step 1: Create as many states as the number of tokens. + # Each state is a final state + fst = kaldifst.StdVectorFst() + for i in range(num_states): + s = fst.add_state() + fst.set_final(state=s, weight=0) + + # Step 2: Set state 0 as the start state. + # We assume the ID of the blank symbol is 0. + fst.start = 0 + + # Step 3: Build a fully connected graph. + for i in range(num_states): + for k in range(num_states): + if i == k and i != 0: + # Remove the self-loop for states of non-blanks + continue + fst.add_arc( + state=i, + arc=kaldifst.StdArc( + ilabel=k, + olabel=k, + weight=0, + nextstate=k, + ), + ) + # Please see ./test_ctc_topo.py if you want to know what the resulting + # FST looks like + + return fst + + def add_one( fst: kaldifst.StdVectorFst, treat_ilabel_zero_specially: bool,