Support modified CTC topology

This commit is contained in:
Fangjun Kuang 2023-10-09 21:58:39 +08:00
parent fefffc02f6
commit 334cd7d03d
7 changed files with 135 additions and 12 deletions

View File

@ -8,6 +8,9 @@ tokens.txt, and words.txt and generates the following files:
- H.fst
- HL.fst
If you also provide --ngram-G, then it also generates
- HLG.fst
Note that saved files are in OpenFst binary format.
@ -22,6 +25,12 @@ Or
./local/prepare_lang_fst.py \
--lang-dir ./data/lang_bpe_500
or
./local/prepare_lang_fst.py \
--lang-dir ./data/lang_bpe_500 \
--ngram-G ./data/lm/G_3_gram.fst.txt
"""
import argparse
@ -35,6 +44,7 @@ from icefall.ctc import (
add_disambig_self_loops,
add_one,
build_standard_ctc_topo,
build_ctc_topo_max_repeat0,
make_lexicon_fst_no_silence,
make_lexicon_fst_with_silence,
)
@ -65,6 +75,17 @@ def get_args():
""",
)
parser.add_argument(
"--max-num-repeats",
type=int,
default=-1,
help="""Allowed maximum number of repeats.
-1 means the standard CTC topology allowing infinite number of repeats.
0 means it does not allow repeats at all.
Any other value is currently not supported.
""",
)
return parser.parse_args()
@ -171,14 +192,22 @@ def main():
lexicon = Lexicon(lang_dir)
logging.info("Building standard CTC topology")
max_token_id = max(lexicon.tokens)
H = build_standard_ctc_topo(max_token_id=max_token_id)
if args.max_num_repeats == -1:
logging.info("Building standard CTC topology")
H = build_standard_ctc_topo(max_token_id=max_token_id)
suffix=''
elif args.max_num_repeats == 0:
logging.info("Building CTC topology allowing 0 repeat")
H = build_ctc_topo_max_repeat0(max_token_id=max_token_id)
suffix='_0'
else:
raise ValueError(f'Unsupported --max-num-repeats: {args.max_num_repeats}. Only -1 and 0 are valid')
# We need to add one to all tokens since we want to use ID 0
# for epsilon
add_one(H, treat_ilabel_zero_specially=False, update_olabel=True)
H.write(f"{lang_dir}/H.fst")
H.write(f"{lang_dir}/H{suffix}.fst")
logging.info("Building L")
# Now for HL
@ -195,7 +224,7 @@ def main():
has_silence=args.has_silence,
lexicon=lexicon,
)
HL.write(f"{lang_dir}/HL.fst")
HL.write(f"{lang_dir}/HL{suffix}.fst")
if not args.ngram_G:
logging.info("Skip building HLG")
@ -209,7 +238,7 @@ def main():
)
HLG = build_HLG(H=H, L=L, G=G, has_silence=args.has_silence, lexicon=lexicon)
HLG.write(f"{lang_dir}/HLG.fst")
HLG.write(f"{lang_dir}/HLG{suffix}.fst")
if __name__ == "__main__":

View File

@ -92,6 +92,7 @@ class OnnxModel:
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
session_opts.log_severity_level = 3 # error level
self.session_opts = session_opts

View File

@ -92,6 +92,7 @@ class OnnxModel:
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
session_opts.log_severity_level = 3 # error level
self.session_opts = session_opts

View File

@ -92,6 +92,7 @@ class OnnxModel:
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
session_opts.log_severity_level = 3 # error level
self.session_opts = session_opts

View File

@ -3,4 +3,9 @@ from .prepare_lang import (
make_lexicon_fst_no_silence,
make_lexicon_fst_with_silence,
)
from .topo import add_disambig_self_loops, add_one, build_standard_ctc_topo
from .topo import (
add_disambig_self_loops,
add_one,
build_ctc_topo_max_repeat0,
build_standard_ctc_topo,
)

View File

@ -11,7 +11,12 @@ from prepare_lang import (
make_lexicon_fst_no_silence,
make_lexicon_fst_with_silence,
)
from topo import add_disambig_self_loops, add_one, build_standard_ctc_topo
from topo import (
add_disambig_self_loops,
add_one,
build_standard_ctc_topo,
build_ctc_topo_max_repeat0,
)
def test_yesno():
@ -131,7 +136,30 @@ def test_librispeech():
print(sp.encode(["HELLOA", "WORLD"]))
def test_build_ctc_topo_max_repeat0():
H = build_ctc_topo_max_repeat0(max_token_id=3)
isym = kaldifst.SymbolTable()
isym.add_symbol(symbol="<blk>", key=0)
isym.add_symbol(symbol="C", key=1)
isym.add_symbol(symbol="A", key=2)
isym.add_symbol(symbol="T", key=3)
osym = kaldifst.SymbolTable()
osym.add_symbol(symbol="<eps>", key=0)
osym.add_symbol(symbol="C", key=1)
osym.add_symbol(symbol="A", key=2)
osym.add_symbol(symbol="T", key=3)
H.input_symbols = isym
H.output_symbols = osym
fst_dot = kaldifst.draw(H, acceptor=False, portrait=True)
source = graphviz.Source(fst_dot)
source.render(outfile="ctc_topo_max_repeat0.pdf")
def main():
test_build_ctc_topo_max_repeat0()
test_yesno()
test_librispeech()

View File

@ -3,15 +3,16 @@
import kaldifst
# Note the name contains `standard`; it means there will be non-standard
# topologies.
def build_standard_ctc_topo(max_token_id: int) -> kaldifst.StdVectorFst:
"""Build a standard CTC topology.
Args:
Maximum valid token ID. We assume token IDs are contiguous
and starts from 0. In other words, the vocabulary size is
``max_token_id + 1``. We assume the ID of the blank symbol is 0.
max_token_id:
Maximum valid token ID. We assume token IDs are contiguous
and starts from 0. In other words, the vocabulary size is
``max_token_id + 1``. We assume the ID of the blank symbol is 0.
Returns:
Return a transducer representing the standard CTC topology.
"""
# Token ID starts from 0 and there are as many states as the
# number of tokens.
@ -54,6 +55,63 @@ def build_standard_ctc_topo(max_token_id: int) -> kaldifst.StdVectorFst:
return fst
def build_ctc_topo_max_repeat0(max_token_id: int) -> kaldifst.StdVectorFst:
"""Build a modified CTC topology which does not allow any repeats.
We remove the self-loop of each state except the state for the blank .
Args:
max_token_id:
Maximum valid token ID. We assume token IDs are contiguous
and starts from 0. In other words, the vocabulary size is
``max_token_id + 1``. We assume the ID of the blank symbol is 0.
Returns:
Return a transducer representing the modified CTC topology.
"""
# Token ID starts from 0 and there are as many states as the
# number of tokens.
#
# Note that epsilon is not a token and the token with ID 0 in tokens.txt
# is not an epsilon. It means input label 0 of the resulting FST does
# not represent an epsilon.
#
# You can use the function `add_one()` to modify the input/output labels
# of the resulting FST
num_states = max_token_id + 1
# Step 1: Create as many states as the number of tokens.
# Each state is a final state
fst = kaldifst.StdVectorFst()
for i in range(num_states):
s = fst.add_state()
fst.set_final(state=s, weight=0)
# Step 2: Set state 0 as the start state.
# We assume the ID of the blank symbol is 0.
fst.start = 0
# Step 3: Build a fully connected graph.
for i in range(num_states):
for k in range(num_states):
if i == k and i != 0:
# Remove the self-loop for states of non-blanks
continue
fst.add_arc(
state=i,
arc=kaldifst.StdArc(
ilabel=k,
olabel=k,
weight=0,
nextstate=k,
),
)
# Please see ./test_ctc_topo.py if you want to know what the resulting
# FST looks like
return fst
def add_one(
fst: kaldifst.StdVectorFst,
treat_ilabel_zero_specially: bool,