Support modified CTC topology

This commit is contained in:
Fangjun Kuang 2023-10-09 21:58:39 +08:00
parent fefffc02f6
commit 334cd7d03d
7 changed files with 135 additions and 12 deletions

View File

@ -8,6 +8,9 @@ tokens.txt, and words.txt and generates the following files:
- H.fst - H.fst
- HL.fst - HL.fst
If you also provide --ngram-G, then it also generates
- HLG.fst - HLG.fst
Note that saved files are in OpenFst binary format. Note that saved files are in OpenFst binary format.
@ -22,6 +25,12 @@ Or
./local/prepare_lang_fst.py \ ./local/prepare_lang_fst.py \
--lang-dir ./data/lang_bpe_500 --lang-dir ./data/lang_bpe_500
or
./local/prepare_lang_fst.py \
--lang-dir ./data/lang_bpe_500 \
--ngram-G ./data/lm/G_3_gram.fst.txt
""" """
import argparse import argparse
@ -35,6 +44,7 @@ from icefall.ctc import (
add_disambig_self_loops, add_disambig_self_loops,
add_one, add_one,
build_standard_ctc_topo, build_standard_ctc_topo,
build_ctc_topo_max_repeat0,
make_lexicon_fst_no_silence, make_lexicon_fst_no_silence,
make_lexicon_fst_with_silence, make_lexicon_fst_with_silence,
) )
@ -65,6 +75,17 @@ def get_args():
""", """,
) )
parser.add_argument(
"--max-num-repeats",
type=int,
default=-1,
help="""Allowed maximum number of repeats.
-1 means the standard CTC topology allowing infinite number of repeats.
0 means it does not allow repeats at all.
Any other value is currently not supported.
""",
)
return parser.parse_args() return parser.parse_args()
@ -171,14 +192,22 @@ def main():
lexicon = Lexicon(lang_dir) lexicon = Lexicon(lang_dir)
logging.info("Building standard CTC topology")
max_token_id = max(lexicon.tokens) max_token_id = max(lexicon.tokens)
if args.max_num_repeats == -1:
logging.info("Building standard CTC topology")
H = build_standard_ctc_topo(max_token_id=max_token_id) H = build_standard_ctc_topo(max_token_id=max_token_id)
suffix=''
elif args.max_num_repeats == 0:
logging.info("Building CTC topology allowing 0 repeat")
H = build_ctc_topo_max_repeat0(max_token_id=max_token_id)
suffix='_0'
else:
raise ValueError(f'Unsupported --max-num-repeats: {args.max_num_repeats}. Only -1 and 0 are valid')
# We need to add one to all tokens since we want to use ID 0 # We need to add one to all tokens since we want to use ID 0
# for epsilon # for epsilon
add_one(H, treat_ilabel_zero_specially=False, update_olabel=True) add_one(H, treat_ilabel_zero_specially=False, update_olabel=True)
H.write(f"{lang_dir}/H.fst") H.write(f"{lang_dir}/H{suffix}.fst")
logging.info("Building L") logging.info("Building L")
# Now for HL # Now for HL
@ -195,7 +224,7 @@ def main():
has_silence=args.has_silence, has_silence=args.has_silence,
lexicon=lexicon, lexicon=lexicon,
) )
HL.write(f"{lang_dir}/HL.fst") HL.write(f"{lang_dir}/HL{suffix}.fst")
if not args.ngram_G: if not args.ngram_G:
logging.info("Skip building HLG") logging.info("Skip building HLG")
@ -209,7 +238,7 @@ def main():
) )
HLG = build_HLG(H=H, L=L, G=G, has_silence=args.has_silence, lexicon=lexicon) HLG = build_HLG(H=H, L=L, G=G, has_silence=args.has_silence, lexicon=lexicon)
HLG.write(f"{lang_dir}/HLG.fst") HLG.write(f"{lang_dir}/HLG{suffix}.fst")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -92,6 +92,7 @@ class OnnxModel:
session_opts = ort.SessionOptions() session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1 session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1 session_opts.intra_op_num_threads = 1
session_opts.log_severity_level = 3 # error level
self.session_opts = session_opts self.session_opts = session_opts

View File

@ -92,6 +92,7 @@ class OnnxModel:
session_opts = ort.SessionOptions() session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1 session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1 session_opts.intra_op_num_threads = 1
session_opts.log_severity_level = 3 # error level
self.session_opts = session_opts self.session_opts = session_opts

View File

@ -92,6 +92,7 @@ class OnnxModel:
session_opts = ort.SessionOptions() session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1 session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1 session_opts.intra_op_num_threads = 1
session_opts.log_severity_level = 3 # error level
self.session_opts = session_opts self.session_opts = session_opts

View File

@ -3,4 +3,9 @@ from .prepare_lang import (
make_lexicon_fst_no_silence, make_lexicon_fst_no_silence,
make_lexicon_fst_with_silence, make_lexicon_fst_with_silence,
) )
from .topo import add_disambig_self_loops, add_one, build_standard_ctc_topo from .topo import (
add_disambig_self_loops,
add_one,
build_ctc_topo_max_repeat0,
build_standard_ctc_topo,
)

View File

@ -11,7 +11,12 @@ from prepare_lang import (
make_lexicon_fst_no_silence, make_lexicon_fst_no_silence,
make_lexicon_fst_with_silence, make_lexicon_fst_with_silence,
) )
from topo import add_disambig_self_loops, add_one, build_standard_ctc_topo from topo import (
add_disambig_self_loops,
add_one,
build_standard_ctc_topo,
build_ctc_topo_max_repeat0,
)
def test_yesno(): def test_yesno():
@ -131,7 +136,30 @@ def test_librispeech():
print(sp.encode(["HELLOA", "WORLD"])) print(sp.encode(["HELLOA", "WORLD"]))
def test_build_ctc_topo_max_repeat0():
    """Render the no-repeat CTC topology of a 4-symbol vocabulary to a PDF.

    The topology is built with ``max_token_id=3``. Symbol tables are attached
    so that the drawing shows readable labels instead of integer IDs, and the
    result is written to ``ctc_topo_max_repeat0.pdf``.
    """
    topo = build_ctc_topo_max_repeat0(max_token_id=3)

    # Input labels: ID 0 is displayed as the blank token.
    isym = kaldifst.SymbolTable()
    for key, symbol in enumerate(["<blk>", "C", "A", "T"]):
        isym.add_symbol(symbol=symbol, key=key)

    # Output labels: ID 0 is displayed as epsilon.
    osym = kaldifst.SymbolTable()
    for key, symbol in enumerate(["<eps>", "C", "A", "T"]):
        osym.add_symbol(symbol=symbol, key=key)

    topo.input_symbols = isym
    topo.output_symbols = osym

    fst_dot = kaldifst.draw(topo, acceptor=False, portrait=True)
    graphviz.Source(fst_dot).render(outfile="ctc_topo_max_repeat0.pdf")
def main(): def main():
test_build_ctc_topo_max_repeat0()
test_yesno() test_yesno()
test_librispeech() test_librispeech()

View File

@ -3,15 +3,16 @@
import kaldifst import kaldifst
# Note the name contains `standard`; it means there will be non-standard
# topologies.
def build_standard_ctc_topo(max_token_id: int) -> kaldifst.StdVectorFst: def build_standard_ctc_topo(max_token_id: int) -> kaldifst.StdVectorFst:
"""Build a standard CTC topology. """Build a standard CTC topology.
Args: Args:
max_token_id:
Maximum valid token ID. We assume token IDs are contiguous Maximum valid token ID. We assume token IDs are contiguous
and starts from 0. In other words, the vocabulary size is and starts from 0. In other words, the vocabulary size is
``max_token_id + 1``. We assume the ID of the blank symbol is 0. ``max_token_id + 1``. We assume the ID of the blank symbol is 0.
Returns:
Return a transducer representing the standard CTC topology.
""" """
# Token ID starts from 0 and there are as many states as the # Token ID starts from 0 and there are as many states as the
# number of tokens. # number of tokens.
@ -54,6 +55,63 @@ def build_standard_ctc_topo(max_token_id: int) -> kaldifst.StdVectorFst:
return fst return fst
def build_ctc_topo_max_repeat0(max_token_id: int) -> kaldifst.StdVectorFst:
    """Build a modified CTC topology which does not allow any repeats.

    Compared with a fully connected graph over all tokens, the self-loop of
    every state is removed, except for the state of the blank symbol.

    Args:
      max_token_id:
        Maximum valid token ID. We assume token IDs are contiguous
        and start from 0. In other words, the vocabulary size is
        ``max_token_id + 1``. We assume the ID of the blank symbol is 0.
    Returns:
      Return a transducer representing the modified CTC topology.
    """
    # There is one state per token, and token IDs start from 0.
    #
    # Note that epsilon is not a token and the token with ID 0 in tokens.txt
    # is not an epsilon. It means input label 0 of the resulting FST does
    # not represent an epsilon.
    #
    # You can use the function `add_one()` to modify the input/output labels
    # of the resulting FST
    num_states = max_token_id + 1

    # Create one state per token; every state is a final state.
    fst = kaldifst.StdVectorFst()
    for _ in range(num_states):
        fst.set_final(state=fst.add_state(), weight=0)

    # The start state is state 0, i.e., the state of the blank symbol.
    fst.start = 0

    # Connect every state to every state. The arc entering state `dst`
    # has `dst` as both its input label and its output label.
    for src in range(num_states):
        for dst in range(num_states):
            is_self_loop = src == dst
            if not is_self_loop or src == 0:
                # Only the blank state (state 0) keeps its self-loop;
                # self-loops of non-blank states are left out.
                fst.add_arc(
                    state=src,
                    arc=kaldifst.StdArc(
                        ilabel=dst,
                        olabel=dst,
                        weight=0,
                        nextstate=dst,
                    ),
                )

    # Please see ./test_ctc_topo.py if you want to know what the resulting
    # FST looks like
    return fst
def add_one( def add_one(
fst: kaldifst.StdVectorFst, fst: kaldifst.StdVectorFst,
treat_ilabel_zero_specially: bool, treat_ilabel_zero_specially: bool,