mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 16:14:17 +00:00
Support modified CTC topology
This commit is contained in:
parent
fefffc02f6
commit
334cd7d03d
@ -8,6 +8,9 @@ tokens.txt, and words.txt and generates the following files:
|
||||
|
||||
- H.fst
|
||||
- HL.fst
|
||||
|
||||
If you also provide --ngram-G, then it also generates
|
||||
|
||||
- HLG.fst
|
||||
|
||||
Note that saved files are in OpenFst binary format.
|
||||
@ -22,6 +25,12 @@ Or
|
||||
|
||||
./local/prepare_lang_fst.py \
|
||||
--lang-dir ./data/lang_bpe_500
|
||||
|
||||
or
|
||||
|
||||
./local/prepare_lang_fst.py \
|
||||
--lang-dir ./data/lang_bpe_500 \
|
||||
--ngram-G ./data/lm/G_3_gram.fst.txt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@ -35,6 +44,7 @@ from icefall.ctc import (
|
||||
add_disambig_self_loops,
|
||||
add_one,
|
||||
build_standard_ctc_topo,
|
||||
build_ctc_topo_max_repeat0,
|
||||
make_lexicon_fst_no_silence,
|
||||
make_lexicon_fst_with_silence,
|
||||
)
|
||||
@ -65,6 +75,17 @@ def get_args():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-num-repeats",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="""Allowed maximum number of repeats.
|
||||
-1 means the standard CTC topology allowing infinite number of repeats.
|
||||
0 means it does not allow repeats at all.
|
||||
Any other value is currently not supported.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -171,14 +192,22 @@ def main():
|
||||
|
||||
lexicon = Lexicon(lang_dir)
|
||||
|
||||
logging.info("Building standard CTC topology")
|
||||
max_token_id = max(lexicon.tokens)
|
||||
H = build_standard_ctc_topo(max_token_id=max_token_id)
|
||||
if args.max_num_repeats == -1:
|
||||
logging.info("Building standard CTC topology")
|
||||
H = build_standard_ctc_topo(max_token_id=max_token_id)
|
||||
suffix=''
|
||||
elif args.max_num_repeats == 0:
|
||||
logging.info("Building CTC topology allowing 0 repeat")
|
||||
H = build_ctc_topo_max_repeat0(max_token_id=max_token_id)
|
||||
suffix='_0'
|
||||
else:
|
||||
raise ValueError(f'Unsupported --max-num-repeats: {args.max_num_repeats}. Only -1 and 0 are valid')
|
||||
|
||||
# We need to add one to all tokens since we want to use ID 0
|
||||
# for epsilon
|
||||
add_one(H, treat_ilabel_zero_specially=False, update_olabel=True)
|
||||
H.write(f"{lang_dir}/H.fst")
|
||||
H.write(f"{lang_dir}/H{suffix}.fst")
|
||||
|
||||
logging.info("Building L")
|
||||
# Now for HL
|
||||
@ -195,7 +224,7 @@ def main():
|
||||
has_silence=args.has_silence,
|
||||
lexicon=lexicon,
|
||||
)
|
||||
HL.write(f"{lang_dir}/HL.fst")
|
||||
HL.write(f"{lang_dir}/HL{suffix}.fst")
|
||||
|
||||
if not args.ngram_G:
|
||||
logging.info("Skip building HLG")
|
||||
@ -209,7 +238,7 @@ def main():
|
||||
)
|
||||
|
||||
HLG = build_HLG(H=H, L=L, G=G, has_silence=args.has_silence, lexicon=lexicon)
|
||||
HLG.write(f"{lang_dir}/HLG.fst")
|
||||
HLG.write(f"{lang_dir}/HLG{suffix}.fst")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -92,6 +92,7 @@ class OnnxModel:
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
session_opts.log_severity_level = 3 # error level
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
|
@ -92,6 +92,7 @@ class OnnxModel:
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
session_opts.log_severity_level = 3 # error level
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
|
@ -92,6 +92,7 @@ class OnnxModel:
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
session_opts.log_severity_level = 3 # error level
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
|
@ -3,4 +3,9 @@ from .prepare_lang import (
|
||||
make_lexicon_fst_no_silence,
|
||||
make_lexicon_fst_with_silence,
|
||||
)
|
||||
from .topo import add_disambig_self_loops, add_one, build_standard_ctc_topo
|
||||
from .topo import (
|
||||
add_disambig_self_loops,
|
||||
add_one,
|
||||
build_ctc_topo_max_repeat0,
|
||||
build_standard_ctc_topo,
|
||||
)
|
||||
|
@ -11,7 +11,12 @@ from prepare_lang import (
|
||||
make_lexicon_fst_no_silence,
|
||||
make_lexicon_fst_with_silence,
|
||||
)
|
||||
from topo import add_disambig_self_loops, add_one, build_standard_ctc_topo
|
||||
from topo import (
|
||||
add_disambig_self_loops,
|
||||
add_one,
|
||||
build_standard_ctc_topo,
|
||||
build_ctc_topo_max_repeat0,
|
||||
)
|
||||
|
||||
|
||||
def test_yesno():
|
||||
@ -131,7 +136,30 @@ def test_librispeech():
|
||||
print(sp.encode(["HELLOA", "WORLD"]))
|
||||
|
||||
|
||||
def test_build_ctc_topo_max_repeat0():
    """Render the no-repeat CTC topology for a tiny vocabulary as a PDF.

    The topology is built for token IDs 0..3 and labelled with the
    symbols C, A and T so that the drawing is easy to read. ID 0 is shown
    as ``<blk>`` on the input side and as ``<eps>`` on the output side.
    The figure is written to ``ctc_topo_max_repeat0.pdf``.
    """

    def make_symbol_table(symbol_for_zero):
        # Both tables share the symbols for IDs 1..3; only the symbol
        # attached to ID 0 differs between the input and output sides.
        table = kaldifst.SymbolTable()
        for key, symbol in enumerate([symbol_for_zero, "C", "A", "T"]):
            table.add_symbol(symbol=symbol, key=key)
        return table

    topo = build_ctc_topo_max_repeat0(max_token_id=3)
    topo.input_symbols = make_symbol_table("<blk>")
    topo.output_symbols = make_symbol_table("<eps>")

    # Draw as a transducer (not an acceptor) so both labels of each arc
    # are visible in the figure.
    dot_text = kaldifst.draw(topo, acceptor=False, portrait=True)
    graphviz.Source(dot_text).render(outfile="ctc_topo_max_repeat0.pdf")
|
||||
|
||||
|
||||
def main():
|
||||
test_build_ctc_topo_max_repeat0()
|
||||
test_yesno()
|
||||
test_librispeech()
|
||||
|
||||
|
@ -3,15 +3,16 @@
|
||||
import kaldifst
|
||||
|
||||
|
||||
# Note the name contains `standard`; it means there will be non-standard
|
||||
# topologies.
|
||||
def build_standard_ctc_topo(max_token_id: int) -> kaldifst.StdVectorFst:
|
||||
"""Build a standard CTC topology.
|
||||
|
||||
Args:
|
||||
Maximum valid token ID. We assume token IDs are contiguous
|
||||
and start from 0. In other words, the vocabulary size is
|
||||
``max_token_id + 1``. We assume the ID of the blank symbol is 0.
|
||||
max_token_id:
|
||||
Maximum valid token ID. We assume token IDs are contiguous
|
||||
and start from 0. In other words, the vocabulary size is
|
||||
``max_token_id + 1``. We assume the ID of the blank symbol is 0.
|
||||
Returns:
|
||||
Return a transducer representing the standard CTC topology.
|
||||
"""
|
||||
# Token ID starts from 0 and there are as many states as the
|
||||
# number of tokens.
|
||||
@ -54,6 +55,63 @@ def build_standard_ctc_topo(max_token_id: int) -> kaldifst.StdVectorFst:
|
||||
return fst
|
||||
|
||||
|
||||
def build_ctc_topo_max_repeat0(max_token_id: int) -> kaldifst.StdVectorFst:
    """Build a modified CTC topology in which no token may be repeated.

    The graph has one state per token and is fully connected, except that
    the self-loop is dropped on every state other than the blank state.
    Every state is final and state 0 is the start state.

    Args:
      max_token_id:
        Maximum valid token ID. We assume token IDs are contiguous
        and start from 0. In other words, the vocabulary size is
        ``max_token_id + 1``. We assume the ID of the blank symbol is 0.
    Returns:
      Return a transducer representing the modified CTC topology.
    """
    # Epsilon is not a token, and the token with ID 0 in tokens.txt is not
    # an epsilon either. Hence input label 0 of the FST built here does
    # not stand for epsilon. Use `add_one()` afterwards to shift the
    # input/output labels of the returned FST.

    # There is exactly one state per token ID.
    vocab_size = max_token_id + 1

    topo = kaldifst.StdVectorFst()

    # Create the states; each of them is a final state.
    for _ in range(vocab_size):
        state = topo.add_state()
        topo.set_final(state=state, weight=0)

    # The blank symbol is assumed to have ID 0, so its state is the start.
    topo.start = 0

    # Connect every state to every state. An arc entering state `dst`
    # consumes and emits label `dst`. A self-loop is kept only on the
    # blank state (state 0), so a non-blank token can never be repeated.
    for src in range(vocab_size):
        for dst in range(vocab_size):
            if src != dst or src == 0:
                arc = kaldifst.StdArc(
                    ilabel=dst,
                    olabel=dst,
                    weight=0,
                    nextstate=dst,
                )
                topo.add_arc(state=src, arc=arc)

    # Please see ./test_ctc_topo.py if you want to know what the resulting
    # FST looks like.
    return topo
|
||||
|
||||
|
||||
def add_one(
|
||||
fst: kaldifst.StdVectorFst,
|
||||
treat_ilabel_zero_specially: bool,
|
||||
|
Loading…
x
Reference in New Issue
Block a user