Support CTC decoding on CPU using OpenFst and kaldi decoders. (#1244)

Fangjun Kuang 2023-09-26 16:36:19 +08:00 committed by GitHub
parent 1b565dd251
commit 2318c3fbd0
25 changed files with 1783 additions and 4 deletions


@ -24,6 +24,7 @@ exclude =
**/data/**,
icefall/shared/make_kn_lm.py,
icefall/__init__.py
icefall/ctc/__init__.py
ignore =
# E203 white space before ":"


@ -44,3 +44,46 @@ log "HLG decoding"
$repo/test_wavs/1089-134686-0001.flac \
$repo/test_wavs/1221-135766-0001.flac \
$repo/test_wavs/1221-135766-0002.flac
log "CTC decoding on CPU with kaldi decoders using OpenFst"
log "Exporting model with torchscript"
pushd $repo/exp
ln -s pretrained.pt epoch-99.pt
popd
./conformer_ctc/export.py \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--jit 1
ls -lh $repo/exp
log "Generating H.fst, HL.fst"
./local/prepare_lang_fst.py --lang-dir $repo/data/lang_bpe_500
ls -lh $repo/data/lang_bpe_500
log "Decoding with H on CPU with OpenFst"
./conformer_ctc/jit_pretrained_decode_with_H.py \
--nn-model $repo/exp/cpu_jit.pt \
--H $repo/data/lang_bpe_500/H.fst \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.flac \
$repo/test_wavs/1221-135766-0001.flac \
$repo/test_wavs/1221-135766-0002.flac
log "Decoding with HL on CPU with OpenFst"
./conformer_ctc/jit_pretrained_decode_with_HL.py \
--nn-model $repo/exp/cpu_jit.pt \
--HL $repo/data/lang_bpe_500/HL.fst \
--words $repo/data/lang_bpe_500/words.txt \
$repo/test_wavs/1089-134686-0001.flac \
$repo/test_wavs/1221-135766-0001.flac \
$repo/test_wavs/1221-135766-0002.flac


@ -29,7 +29,7 @@ concurrency:
jobs:
run_pre_trained_conformer_ctc:
- if: github.event.label.name == 'ready' || github.event_name == 'push'
+ if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'ctc'
runs-on: ${{ matrix.os }}
strategy:
matrix:


@ -140,9 +140,46 @@ jobs:
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
download/waves_yesno/0_0_1_0_0_0_1_0.wav
- name: Test decoding with H
shell: bash
working-directory: ${{github.workspace}}
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
echo $PYTHONPATH
cd egs/yesno/ASR
python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1
python3 ./tdnn/jit_pretrained_decode_with_H.py \
--nn-model ./tdnn/exp/cpu_jit.pt \
--H ./data/lang_phone/H.fst \
--tokens ./data/lang_phone/tokens.txt \
./download/waves_yesno/0_0_0_1_0_0_0_1.wav \
./download/waves_yesno/0_0_1_0_0_0_1_0.wav \
./download/waves_yesno/0_0_1_0_0_1_1_1.wav
- name: Test decoding with HL
shell: bash
working-directory: ${{github.workspace}}
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
echo $PYTHONPATH
cd egs/yesno/ASR
python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1
python3 ./tdnn/jit_pretrained_decode_with_HL.py \
--nn-model ./tdnn/exp/cpu_jit.pt \
--HL ./data/lang_phone/HL.fst \
--words ./data/lang_phone/words.txt \
./download/waves_yesno/0_0_0_1_0_0_0_1.wav \
./download/waves_yesno/0_0_1_0_0_0_1_0.wav \
./download/waves_yesno/0_0_1_0_0_1_1_1.wav
- name: Show generated files
shell: bash
working-directory: ${{github.workspace}}
run: |
cd egs/yesno/ASR
ls -lh tdnn/exp
ls -lh data/lang_phone

.gitignore

@ -34,3 +34,5 @@ node_modules
*.param
*.bin
.DS_Store
*.fst
*.arpa


@ -1,3 +1,5 @@
.. _icefall_export_to_ncnn:
Export to ncnn
==============


@ -0,0 +1,235 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This file shows how to use a torchscript model for decoding with H
on CPU using OpenFst and decoders from kaldi.
Usage:
./conformer_ctc/jit_pretrained_decode_with_H.py \
--nn-model ./conformer_ctc/exp/cpu_jit.pt \
--H ./data/lang_bpe_500/H.fst \
--tokens ./data/lang_bpe_500/tokens.txt \
./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \
./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac
Note that to generate ./conformer_ctc/exp/cpu_jit.pt,
you can use ./export.py --jit 1
"""
import argparse
import logging
import math
from typing import Dict, List
import kaldifeat
import kaldifst
import torch
import torchaudio
from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model",
type=str,
required=True,
help="""Path to the torchscript model.
You can use ./conformer_ctc/export.py --jit 1
to obtain it
""",
)
parser.add_argument(
"--tokens",
type=str,
required=True,
help="Path to tokens.txt",
)
parser.add_argument("--H", type=str, required=True, help="Path to H.fst")
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. ",
)
return parser
def read_tokens(tokens_txt: str) -> Dict[int, str]:
id2token = dict()
with open(tokens_txt, encoding="utf-8") as f:
for line in f:
token, idx = line.strip().split()
id2token[int(idx)] = token
return id2token
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
if sample_rate != expected_sample_rate:
wave = torchaudio.functional.resample(
wave,
orig_freq=sample_rate,
new_freq=expected_sample_rate,
)
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
def decode(
filename: str,
nnet_output: torch.Tensor,
H: kaldifst.StdVectorFst,
id2token: Dict[int, str],
) -> List[str]:
"""
Args:
filename:
Path to the filename for decoding. Used for debugging.
nnet_output:
A 2-D float32 tensor of shape (num_frames, vocab_size). It
contains output from log_softmax.
H:
The H graph.
id2token:
A map mapping token ID to token string.
Returns:
Return a list of decoded tokens.
"""
logging.info(f"{filename}, {nnet_output.shape}")
decodable = DecodableCtc(nnet_output.cpu())
decoder_opts = FasterDecoderOptions(max_active=3000)
decoder = FasterDecoder(H, decoder_opts)
decoder.decode(decodable)
if not decoder.reached_final():
print(f"failed to decode {filename}")
return [""]
ok, best_path = decoder.get_best_path()
(
ok,
isymbols_out,
osymbols_out,
total_weight,
) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok:
print(f"failed to get linear symbol sequence for {filename}")
return [""]
# tokens are incremented during graph construction
# so they need to be decremented
hyps = [id2token[i - 1] for i in osymbols_out]
# hyps = "".join(hyps).split("▁")
hyps = "".join(hyps).split("\u2581") # unicode codepoint of ▁
return hyps
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
device = torch.device("cpu")
logging.info(f"device: {device}")
logging.info("Loading torchscript model")
model = torch.jit.load(args.nn_model)
model.eval()
model.to(device)
logging.info(f"Loading H from {args.H}")
H = kaldifst.StdVectorFst.read(args.H)
sample_rate = 16000
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files, expected_sample_rate=sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.shape[0] for f in features]
feature_lengths = torch.tensor(feature_lengths)
supervisions = dict()
supervisions["sequence_idx"] = torch.arange(len(features))
supervisions["start_frame"] = torch.zeros(len(features))
supervisions["num_frames"] = feature_lengths
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
nnet_output, _, _ = model(features, supervisions)
feature_lengths = ((feature_lengths - 1) // 2 - 1) // 2
id2token = read_tokens(args.tokens)
hyps = []
for i in range(nnet_output.shape[0]):
hyp = decode(
filename=args.sound_files[i],
nnet_output=nnet_output[i, : feature_lengths[i]],
H=H,
id2token=id2token,
)
hyps.append(hyp)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
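
The post-processing at the end of `decode()` is worth spelling out: the output labels returned by `kaldifst.get_linear_symbol_sequence()` are one-based because `add_one()` shifted every token ID during graph construction, and BPE pieces mark word starts with ▁ (U+2581). A minimal sketch of that recovery step, with a made-up token table and label sequence for illustration:

```python
# Sketch of the hypothesis recovery in decode(); the token table and
# the label sequence below are made up for illustration.
id2token = {1: "\u2581HEL", 2: "LO", 3: "\u2581WORLD"}  # ▁HEL, LO, ▁WORLD

osymbols_out = [2, 3, 4]  # decoder output labels, shifted up by one

tokens = [id2token[i - 1] for i in osymbols_out]  # undo the shift
words = "".join(tokens).split("\u2581")  # "\u2581" is the codepoint of ▁
print([w for w in words if w])  # ['HELLO', 'WORLD']
```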


@ -0,0 +1,232 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This file shows how to use a torchscript model for decoding with HL
on CPU using OpenFst and decoders from kaldi.
Usage:
./conformer_ctc/jit_pretrained_decode_with_HL.py \
--nn-model ./conformer_ctc/exp/cpu_jit.pt \
--HL ./data/lang_bpe_500/HL.fst \
--words ./data/lang_bpe_500/words.txt \
./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \
./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac
Note that to generate ./conformer_ctc/exp/cpu_jit.pt,
you can use ./export.py --jit 1
"""
import argparse
import logging
import math
from typing import Dict, List
import kaldifeat
import kaldifst
import torch
import torchaudio
from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model",
type=str,
required=True,
help="""Path to the torchscript model.
You can use ./conformer_ctc/export.py --jit 1
to obtain it
""",
)
parser.add_argument(
"--words",
type=str,
required=True,
help="Path to words.txt",
)
parser.add_argument("--HL", type=str, required=True, help="Path to HL.fst")
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. ",
)
return parser
def read_words(words_txt: str) -> Dict[int, str]:
id2word = dict()
with open(words_txt, encoding="utf-8") as f:
for line in f:
word, idx = line.strip().split()
id2word[int(idx)] = word
return id2word
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
if sample_rate != expected_sample_rate:
wave = torchaudio.functional.resample(
wave,
orig_freq=sample_rate,
new_freq=expected_sample_rate,
)
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
def decode(
filename: str,
nnet_output: torch.Tensor,
HL: kaldifst.StdVectorFst,
id2word: Dict[int, str],
) -> List[str]:
"""
Args:
filename:
Path to the filename for decoding. Used for debugging.
nnet_output:
A 2-D float32 tensor of shape (num_frames, vocab_size). It
contains output from log_softmax.
HL:
The HL graph.
id2word:
A map mapping word ID to word string.
Returns:
Return a list of decoded words.
"""
logging.info(f"{filename}, {nnet_output.shape}")
decodable = DecodableCtc(nnet_output.cpu())
decoder_opts = FasterDecoderOptions(max_active=3000)
decoder = FasterDecoder(HL, decoder_opts)
decoder.decode(decodable)
if not decoder.reached_final():
print(f"failed to decode {filename}")
return [""]
ok, best_path = decoder.get_best_path()
(
ok,
isymbols_out,
osymbols_out,
total_weight,
) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok:
print(f"failed to get linear symbol sequence for {filename}")
return [""]
# Output labels of HL are word IDs, which are not shifted
hyps = [id2word[i] for i in osymbols_out]
return hyps
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
device = torch.device("cpu")
logging.info(f"device: {device}")
logging.info("Loading torchscript model")
model = torch.jit.load(args.nn_model)
model.eval()
model.to(device)
logging.info(f"Loading HL from {args.HL}")
HL = kaldifst.StdVectorFst.read(args.HL)
sample_rate = 16000
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files, expected_sample_rate=sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.shape[0] for f in features]
feature_lengths = torch.tensor(feature_lengths)
supervisions = dict()
supervisions["sequence_idx"] = torch.arange(len(features))
supervisions["start_frame"] = torch.zeros(len(features))
supervisions["num_frames"] = feature_lengths
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
nnet_output, _, _ = model(features, supervisions)
feature_lengths = ((feature_lengths - 1) // 2 - 1) // 2
id2word = read_words(args.words)
hyps = []
for i in range(nnet_output.shape[0]):
hyp = decode(
filename=args.sound_files[i],
nnet_output=nnet_output[i, : feature_lengths[i]],
HL=HL,
id2word=id2word,
)
hyps.append(hyp)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()


@ -0,0 +1,127 @@
#!/usr/bin/env python3
# Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
"""
This script takes as input lang_dir containing lexicon_disambig.txt,
tokens.txt, and words.txt and generates the following files:
- H.fst
- HL.fst
Note that saved files are in OpenFst binary format.
Usage:
./local/prepare_lang_fst.py \
--lang-dir ./data/lang_phone \
--has-silence 1
Or
./local/prepare_lang_fst.py \
--lang-dir ./data/lang_bpe_500
"""
import argparse
import logging
from pathlib import Path
import kaldifst
from icefall.ctc import (
Lexicon,
add_disambig_self_loops,
add_one,
build_standard_ctc_topo,
make_lexicon_fst_no_silence,
make_lexicon_fst_with_silence,
)
from icefall.utils import str2bool
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)
parser.add_argument(
"--has-silence",
type=str2bool,
default=False,
help="True if the lexicon has silence.",
)
return parser.parse_args()
def main():
args = get_args()
lang_dir = args.lang_dir
lexicon = Lexicon(lang_dir)
logging.info("Building standard CTC topology")
max_token_id = max(lexicon.tokens)
H = build_standard_ctc_topo(max_token_id=max_token_id)
# We need to add one to all tokens since we want to use ID 0
# for epsilon
add_one(H, treat_ilabel_zero_specially=False, update_olabel=True)
H.write(f"{lang_dir}/H.fst")
logging.info("Building L")
# Now for HL
if args.has_silence:
L = make_lexicon_fst_with_silence(lexicon, attach_symbol_table=False)
else:
L = make_lexicon_fst_no_silence(lexicon, attach_symbol_table=False)
if args.has_silence:
# We also need to change the input labels of L
add_one(L, treat_ilabel_zero_specially=True, update_olabel=False)
else:
add_one(L, treat_ilabel_zero_specially=False, update_olabel=False)
# Invoke add_disambig_self_loops() so that it eats the disambig symbols
# from L after composition
add_disambig_self_loops(
H,
start=lexicon.token2id["#0"] + 1,
end=lexicon.max_disambig_id + 1,
)
with open("H_1.fst.txt", "w") as f:
print(H, file=f)
kaldifst.arcsort(H, sort_type="olabel")
kaldifst.arcsort(L, sort_type="ilabel")
logging.info("Building HL")
HL = kaldifst.compose(H, L)
kaldifst.determinize_star(HL)
disambig0 = lexicon.token2id["#0"] + 1
max_disambig = lexicon.max_disambig_id + 1
for state in kaldifst.StateIterator(HL):
for arc in kaldifst.ArcIterator(HL, state):
# Replace the input disambig labels with epsilon (0) so that
# HL does not expect disambig symbols at decoding time
if disambig0 <= arc.ilabel <= max_disambig:
arc.ilabel = 0
# Note: We are not composing L with G, so there is no need to add
# self-loops to L to handle #0
HL.write(f"{lang_dir}/HL.fst")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
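
After the script finishes, the generated graphs can be sanity-checked with kaldifst. A minimal sketch, assuming the LibriSpeech lang directory from the usage above:

```python
# Sanity-check the H.fst and HL.fst written by prepare_lang_fst.py.
import kaldifst

lang_dir = "data/lang_bpe_500"  # assumed location; adjust to your setup

for name in ["H", "HL"]:
    fst = kaldifst.StdVectorFst.read(f"{lang_dir}/{name}.fst")
    num_states = sum(1 for _ in kaldifst.StateIterator(fst))
    num_arcs = sum(
        1
        for s in kaldifst.StateIterator(fst)
        for _ in kaldifst.ArcIterator(fst, s)
    )
    print(f"{name}.fst: {num_states} states, {num_arcs} arcs")
```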


@ -57,8 +57,7 @@ def test_model():
convert_scaled_to_non_scaled(model, inplace=True)
- if not os.path.exists(params.exp_dir):
-     os.path.mkdir(params.exp_dir)
+ params.exp_dir.mkdir(exist_ok=True)
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)


@ -242,6 +242,10 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
$lang_dir/L_disambig.pt \
$lang_dir/L_disambig.fst
fi
if [ ! -f $lang_dir/HL.fst ]; then
./local/prepare_lang_fst.py --lang-dir $lang_dir
fi
done
fi


@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang_fst.py


@ -60,6 +60,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
) > $lang_dir/lexicon.txt
./local/prepare_lang.py
./local/prepare_lang_fst.py --lang-dir ./data/lang_phone --has-silence 1
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then


@ -156,7 +156,6 @@ def main():
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
- # Note: We don't use key padding mask for attention during decoding
nnet_output = model(features)
batch_size = nnet_output.shape[0]


@ -0,0 +1,208 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This file shows how to use a torchscript model for decoding with H
on CPU using OpenFst and decoders from kaldi.
Usage:
./tdnn/jit_pretrained_decode_with_H.py \
--nn-model ./tdnn/exp/cpu_jit.pt \
--H ./data/lang_phone/H.fst \
--tokens ./data/lang_phone/tokens.txt \
./download/waves_yesno/0_0_0_1_0_0_0_1.wav \
./download/waves_yesno/0_0_1_0_0_0_1_0.wav \
./download/waves_yesno/0_0_1_0_0_1_1_1.wav
Note that to generate ./tdnn/exp/cpu_jit.pt,
you can use ./export.py --jit 1
"""
import argparse
import logging
import math
from typing import Dict, List
import kaldifeat
import kaldifst
import torch
import torchaudio
from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model",
type=str,
required=True,
help="""Path to the torchscript model.
You can use ./tdnn/export.py --jit 1
to obtain it
""",
)
parser.add_argument(
"--tokens",
type=str,
required=True,
help="Path to tokens.txt",
)
parser.add_argument("--H", type=str, required=True, help="Path to H.fst")
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. ",
)
return parser
def read_tokens(tokens_txt: str) -> Dict[int, str]:
id2token = dict()
with open(tokens_txt, encoding="utf-8") as f:
for line in f:
token, idx = line.strip().split()
id2token[int(idx)] = token
return id2token
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
if sample_rate != expected_sample_rate:
wave = torchaudio.functional.resample(
wave,
orig_freq=sample_rate,
new_freq=expected_sample_rate,
)
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
def decode(
filename: str,
nnet_output: torch.Tensor,
H: kaldifst.StdVectorFst,
id2token: Dict[int, str],
) -> List[str]:
decodable = DecodableCtc(nnet_output)
decoder_opts = FasterDecoderOptions(max_active=3000)
decoder = FasterDecoder(H, decoder_opts)
decoder.decode(decodable)
if not decoder.reached_final():
print(f"failed to decode {filename}")
return [""]
ok, best_path = decoder.get_best_path()
(
ok,
isymbols_out,
osymbols_out,
total_weight,
) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok:
print(f"failed to get linear symbol sequence for {filename}")
return [""]
# Token IDs were shifted by 1 during graph construction, so shift
# them back; also drop the silence phone
hyps = [id2token[i - 1] for i in osymbols_out if id2token[i - 1] != "SIL"]
return hyps
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
device = torch.device("cpu")
logging.info(f"device: {device}")
logging.info("Loading torchscript model")
model = torch.jit.load(args.nn_model)
model.eval()
model.to(device)
logging.info(f"Loading H from {args.H}")
H = kaldifst.StdVectorFst.read(args.H)
sample_rate = 8000
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 23
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files, expected_sample_rate=sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
nnet_output = model(features)
id2token = read_tokens(args.tokens)
hyps = []
for i in range(nnet_output.shape[0]):
hyp = decode(
filename=args.sound_files[i],
nnet_output=nnet_output[i],
H=H,
id2token=id2token,
)
hyps.append(hyp)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()


@ -0,0 +1,207 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This file shows how to use a torchscript model for decoding with HL
on CPU using OpenFst and decoders from kaldi.
Usage:
./tdnn/jit_pretrained_decode_with_HL.py \
--nn-model ./tdnn/exp/cpu_jit.pt \
--HL ./data/lang_phone/HL.fst \
--words ./data/lang_phone/words.txt \
./download/waves_yesno/0_0_0_1_0_0_0_1.wav \
./download/waves_yesno/0_0_1_0_0_0_1_0.wav \
./download/waves_yesno/0_0_1_0_0_1_1_1.wav
Note that to generate ./tdnn/exp/cpu_jit.pt,
you can use ./export.py --jit 1
"""
import argparse
import logging
import math
from typing import Dict, List
import kaldifeat
import kaldifst
import torch
import torchaudio
from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--nn-model",
type=str,
required=True,
help="""Path to the torchscript model.
You can use ./tdnn/export.py --jit 1
to obtain it
""",
)
parser.add_argument(
"--words",
type=str,
required=True,
help="Path to words.txt",
)
parser.add_argument("--HL", type=str, required=True, help="Path to HL.fst")
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. ",
)
return parser
def read_words(words_txt: str) -> Dict[int, str]:
id2word = dict()
with open(words_txt, encoding="utf-8") as f:
for line in f:
word, idx = line.strip().split()
id2word[int(idx)] = word
return id2word
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
if sample_rate != expected_sample_rate:
wave = torchaudio.functional.resample(
wave,
orig_freq=sample_rate,
new_freq=expected_sample_rate,
)
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
def decode(
filename: str,
nnet_output: torch.Tensor,
HL: kaldifst.StdVectorFst,
id2word: Dict[int, str],
) -> List[str]:
decodable = DecodableCtc(nnet_output)
decoder_opts = FasterDecoderOptions(max_active=3000)
decoder = FasterDecoder(HL, decoder_opts)
decoder.decode(decodable)
if not decoder.reached_final():
print(f"failed to decode {filename}")
return [""]
ok, best_path = decoder.get_best_path()
(
ok,
isymbols_out,
osymbols_out,
total_weight,
) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok:
print(f"failed to get linear symbol sequence for {filename}")
return [""]
hyps = [id2word[i] for i in osymbols_out if id2word[i] != "<SIL>"]
return hyps
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
device = torch.device("cpu")
logging.info(f"device: {device}")
logging.info("Loading torchscript model")
model = torch.jit.load(args.nn_model)
model.eval()
model.to(device)
logging.info(f"Loading HL from {args.HL}")
HL = kaldifst.StdVectorFst.read(args.HL)
sample_rate = 8000
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 23
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files, expected_sample_rate=sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
nnet_output = model(features)
id2word = read_words(args.words)
hyps = []
for i in range(nnet_output.shape[0]):
hyp = decode(
filename=args.sound_files[i],
nnet_output=nnet_output[i],
HL=HL,
id2word=id2word,
)
hyps.append(hyp)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

icefall/ctc/.gitignore

@ -0,0 +1,2 @@
*.pdf
*.gv

icefall/ctc/README.md

@ -0,0 +1,17 @@
# Introduction
This folder uses [kaldifst][kaldifst] for graph construction
and decoders from [kaldi-hmm-gmm][kaldi-hmm-gmm] for CTC decoding.
Decoding is supported only on `CPU`.
You can use
```bash
pip install kaldifst kaldi-hmm-gmm
```
to install the dependencies.
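
For instance, the core decoding loop built from the two packages looks roughly like this (a minimal sketch: the `HL.fst` path is an assumption and the random scores stand in for real CTC log-probabilities from a model):

```python
import kaldifst
import torch
from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions

HL = kaldifst.StdVectorFst.read("data/lang_phone/HL.fst")  # assumed path
nnet_output = torch.randn(100, 5).log_softmax(dim=-1)  # placeholder scores

decodable = DecodableCtc(nnet_output.cpu())
decoder = FasterDecoder(HL, FasterDecoderOptions(max_active=3000))
decoder.decode(decodable)

if decoder.reached_final():
    ok, best_path = decoder.get_best_path()
    ok, isyms, osyms, weight = kaldifst.get_linear_symbol_sequence(best_path)
    print(osyms)  # word IDs; map them back to words via words.txt
```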
[kaldi-hmm-gmm]: https://github.com/csukuangfj/kaldi-hmm-gmm
[kaldifst]: https://github.com/k2-fsa/kaldifst
[k2]: https://github.com/k2-fsa/k2

icefall/ctc/__init__.py

@ -0,0 +1,6 @@
from .prepare_lang import (
Lexicon,
make_lexicon_fst_no_silence,
make_lexicon_fst_with_silence,
)
from .topo import add_disambig_self_loops, add_one, build_standard_ctc_topo

icefall/ctc/prepare_lang.py

@ -0,0 +1,334 @@
# Copyright 2023 Xiaomi Corp. (author: Fangjun Kuang)
"""
The lang_dir should contain the following files:
- "lexicon_disambig.txt"
- "tokens.txt"
- "words.txt"
"""
import math
import re
from collections import defaultdict
from pathlib import Path
from typing import Iterator, List, Tuple
import kaldifst
class Lexicon:
"""Once constructed it is immutable"""
def __init__(
self,
lang_dir: str,
disambig_pattern: re.Pattern = re.compile(r"^#\d+$"),
):
"""
Args:
lang_dir:
The path to the lang directory. We expect that it contains the
following files:
- lexicon_disambig.txt
- tokens.txt
- words.txt
The format of the above files is described below.
(1) lexicon_disambig.txt
Each line in the lexicon_disambig.txt has the following format:
word token1 token2 ... tokenN
That is, the first field is the word, and the remaining fields are
the tokens in the pronunciation of this word. Fields are separated by space(s).
(2) tokens.txt
Each line in tokens.txt has two fields separated by space(s):
token ID
The first field is the token symbol and the second field is the
integer ID of the token.
(3) words.txt
Each line in words.txt has two fields separated by space(s):
word ID
The first field is the word symbol and the second field is the
integer ID of the word.
disambig_pattern:
It contains the pattern for disambiguation symbols.
"""
lang_dir = Path(lang_dir)
lexicon_txt = lang_dir / "lexicon_disambig.txt"
tokens_txt = lang_dir / "tokens.txt"
words_txt = lang_dir / "words.txt"
assert lexicon_txt.is_file(), lexicon_txt
assert tokens_txt.is_file(), tokens_txt
assert words_txt.is_file(), words_txt
self._read_lexicon(lexicon_txt)
self._read_tokens(tokens_txt)
self._read_words(words_txt)
self.disambig_pattern = disambig_pattern
max_disambig_id = -1
for s, i in self.token2id.items():
if self.disambig_pattern.match(s) and i > max_disambig_id:
max_disambig_id = i
self.max_disambig_id = max_disambig_id
def _read_lexicon(self, lexicon_txt: str):
word2phones = defaultdict(list)
with open(lexicon_txt, encoding="utf-8") as f:
for line in f:
word_phones = line.strip().split()
assert len(word_phones) >= 2, (word_phones, line)
word = word_phones[0]
phones: str = " ".join(word_phones[1:])
word2phones[word].append(phones)
# We use a list here since a word may have multiple
# pronunciations
self.word2phones = word2phones
def _read_tokens(self, tokens_txt):
token2id = dict()
id2token = dict()
with open(tokens_txt, encoding="utf-8") as f:
for line in f:
token_id = line.strip().split()
assert len(token_id) == 2, token_id
token = token_id[0]
idx = int(token_id[1])
assert token not in token2id, f"Duplicate token {line}"
assert idx not in id2token, f"Duplicate ID {line}"
token2id[token] = idx
id2token[idx] = token
self.token2id = token2id
self.id2token = id2token
def _read_words(self, words_txt):
word2id = dict()
id2word = dict()
with open(words_txt, encoding="utf-8") as f:
for line in f:
word_id = line.strip().split()
assert len(word_id) == 2, word_id
word = word_id[0]
idx = int(word_id[1])
assert word not in word2id, f"Duplicate word {line}"
assert idx not in id2word, f"Duplicate ID {line}"
word2id[word] = idx
id2word[idx] = word
self.word2id = word2id
self.id2word = id2word
def __iter__(self) -> Iterator[Tuple[str, str]]:
for word, phones_list in self.word2phones.items():
for phones in phones_list:
yield word, phones
def __str__(self):
return str(self.word2phones)
@property
def tokens(self) -> List[int]:
"""Return a list of token IDs excluding those from
disambiguation symbols.
Caution:
0 is not a token ID so it is excluded from the return value.
"""
ans = []
for s in self.token2id:
if not self.disambig_pattern.match(s):
ans.append(self.token2id[s])
if 0 in ans:
ans.remove(0)
ans.sort()
return ans
# See also
# http://vpanayotov.blogspot.com/2012/06/kaldi-decoding-graph-construction.html
def make_lexicon_fst_with_silence(
lexicon: Lexicon,
sil_prob: float = 0.5,
sil_phone: str = "SIL",
attach_symbol_table: bool = True,
) -> kaldifst.StdVectorFst:
phone2id = lexicon.token2id
word2id = lexicon.word2id
assert sil_phone in phone2id, sil_phone
sil_cost = -1 * math.log(sil_prob)
no_sil_cost = -1 * math.log(1.0 - sil_prob)
fst = kaldifst.StdVectorFst()
start_state = fst.add_state()
loop_state = fst.add_state()
sil_state = fst.add_state()
fst.start = start_state
fst.set_final(state=loop_state, weight=0)
fst.add_arc(
state=start_state,
arc=kaldifst.StdArc(
ilabel=0,
olabel=0,
weight=no_sil_cost,
nextstate=loop_state,
),
)
fst.add_arc(
state=start_state,
arc=kaldifst.StdArc(
ilabel=0,
olabel=0,
weight=sil_cost,
nextstate=sil_state,
),
)
fst.add_arc(
state=sil_state,
arc=kaldifst.StdArc(
ilabel=phone2id[sil_phone],
olabel=0,
weight=0,
nextstate=loop_state,
),
)
for word, phones in lexicon:
phoneseq = phones.split()
pron_cost = 0
cur_state = loop_state
for i in range(len(phoneseq) - 1):
next_state = fst.add_state()
fst.add_arc(
state=cur_state,
arc=kaldifst.StdArc(
ilabel=phone2id[phoneseq[i]],
olabel=word2id[word] if i == 0 else 0,
weight=pron_cost if i == 0 else 0,
nextstate=next_state,
),
)
cur_state = next_state
i = len(phoneseq) - 1 # note: i == -1 if phoneseq is empty.
fst.add_arc(
state=cur_state,
arc=kaldifst.StdArc(
ilabel=phone2id[phoneseq[i]] if i >= 0 else 0,
olabel=word2id[word] if i <= 0 else 0,
weight=no_sil_cost + (pron_cost if i <= 0 else 0),
nextstate=loop_state,
),
)
fst.add_arc(
state=cur_state,
arc=kaldifst.StdArc(
ilabel=phone2id[phoneseq[i]] if i >= 0 else 0,
olabel=word2id[word] if i <= 0 else 0,
weight=sil_cost + (pron_cost if i <= 0 else 0),
nextstate=sil_state,
),
)
if attach_symbol_table:
isym = kaldifst.SymbolTable()
for p, i in phone2id.items():
isym.add_symbol(symbol=p, key=i)
fst.input_symbols = isym
osym = kaldifst.SymbolTable()
for w, i in word2id.items():
osym.add_symbol(symbol=w, key=i)
fst.output_symbols = osym
return fst
def make_lexicon_fst_no_silence(
lexicon: Lexicon,
attach_symbol_table: bool = True,
) -> kaldifst.StdVectorFst:
phone2id = lexicon.token2id
word2id = lexicon.word2id
fst = kaldifst.StdVectorFst()
start_state = fst.add_state()
fst.start = start_state
fst.set_final(state=start_state, weight=0)
for word, phones in lexicon:
phoneseq = phones.split()
pron_cost = 0
cur_state = start_state
for i in range(len(phoneseq) - 1):
next_state = fst.add_state()
fst.add_arc(
state=cur_state,
arc=kaldifst.StdArc(
ilabel=phone2id[phoneseq[i]],
olabel=word2id[word] if i == 0 else 0,
weight=pron_cost if i == 0 else 0,
nextstate=next_state,
),
)
cur_state = next_state
i = len(phoneseq) - 1 # note: i == -1 if phoneseq is empty.
fst.add_arc(
state=cur_state,
arc=kaldifst.StdArc(
ilabel=phone2id[phoneseq[i]] if i >= 0 else 0,
olabel=word2id[word] if i <= 0 else 0,
weight=pron_cost if i <= 0 else 0,
nextstate=start_state,
),
)
if attach_symbol_table:
isym = kaldifst.SymbolTable()
for p, i in phone2id.items():
isym.add_symbol(symbol=p, key=i)
fst.input_symbols = isym
osym = kaldifst.SymbolTable()
for w, i in word2id.items():
osym.add_symbol(symbol=w, key=i)
fst.output_symbols = osym
return fst
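
To see `Lexicon` and the lexicon-FST builders in action, a throwaway lang directory with two words is enough; the file contents below are made up for illustration:

```python
# Exercise Lexicon with a toy lang dir; all file contents are hypothetical.
import tempfile
from pathlib import Path

import kaldifst

from icefall.ctc import Lexicon, make_lexicon_fst_no_silence

d = Path(tempfile.mkdtemp())
(d / "lexicon_disambig.txt").write_text("YES Y\nNO N\n")
(d / "tokens.txt").write_text("<blk> 0\nY 1\nN 2\n#0 3\n")
(d / "words.txt").write_text("<eps> 0\nYES 1\nNO 2\n")

lexicon = Lexicon(d)
print(lexicon.tokens)  # [1, 2]: ID 0 and disambig symbols are excluded
print(lexicon.max_disambig_id)  # 3

L = make_lexicon_fst_no_silence(lexicon, attach_symbol_table=False)
kaldifst.arcsort(L, sort_type="ilabel")
print(L)  # text-format dump of the lexicon FST
```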

icefall/ctc/test_ctc_topo.py

@ -0,0 +1,140 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
from pathlib import Path
import graphviz
import kaldifst
import sentencepiece as spm
from prepare_lang import (
Lexicon,
make_lexicon_fst_no_silence,
make_lexicon_fst_with_silence,
)
from topo import add_disambig_self_loops, add_one, build_standard_ctc_topo
def test_yesno():
lang_dir = "/Users/fangjun/open-source/icefall/egs/yesno/ASR/data/lang_phone"
if not Path(lang_dir).is_dir():
print(f"{lang_dir} does not exist! Skip testing")
return
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
H = build_standard_ctc_topo(max_token_id=max_token_id)
isym = kaldifst.SymbolTable()
isym.add_symbol(symbol="<blk>", key=0)
for i in range(1, max_token_id + 1):
isym.add_symbol(symbol=lexicon.id2token[i], key=i)
osym = kaldifst.SymbolTable()
osym.add_symbol(symbol="<eps>", key=0)
for i in range(1, max_token_id + 1):
osym.add_symbol(symbol=lexicon.id2token[i], key=i)
H.input_symbols = isym
H.output_symbols = osym
fst_dot = kaldifst.draw(H, acceptor=False, portrait=True)
source = graphviz.Source(fst_dot)
source.render(outfile="standard_ctc_topo_yesno.pdf")
# See the link below to visualize the above PDF
# https://t.ly/7uXZ9
# Now test HL
# We need to add one to all tokens since we want to use ID 0
# for epsilon
add_one(H, treat_ilabel_zero_specially=False, update_olabel=True)
add_disambig_self_loops(
H,
start=lexicon.token2id["#0"] + 1,
end=lexicon.max_disambig_id + 1,
)
fst_dot = kaldifst.draw(H, acceptor=False, portrait=True)
source = graphviz.Source(fst_dot)
source.render(outfile="standard_ctc_topo_disambig_yesno.pdf")
L = make_lexicon_fst_with_silence(lexicon)
# We also need to change the input labels of L
add_one(L, treat_ilabel_zero_specially=True, update_olabel=False)
H.output_symbols = None
kaldifst.arcsort(H, sort_type="olabel")
kaldifst.arcsort(L, sort_type="ilabel")
HL = kaldifst.compose(H, L)
lexicon.id2token[0] = "<blk>"
lexicon.token2id["<blk>"] = 0
isym = kaldifst.SymbolTable()
isym.add_symbol(symbol="<eps>", key=0)
for i in range(0, lexicon.max_disambig_id + 1):
isym.add_symbol(symbol=lexicon.id2token[i], key=i + 1)
osym = kaldifst.SymbolTable()
for i, word in lexicon.id2word.items():
osym.add_symbol(symbol=word, key=i)
HL.input_symbols = isym
HL.output_symbols = osym
fst_dot = kaldifst.draw(HL, acceptor=False, portrait=True)
source = graphviz.Source(fst_dot)
source.render(outfile="HL_yesno.pdf")
def test_librispeech():
lang_dir = (
"/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/data/lang_bpe_500"
)
if not Path(lang_dir).is_dir():
print(f"{lang_dir} does not exist! Skip testing")
return
lexicon = Lexicon(lang_dir)
HL = kaldifst.StdVectorFst.read(lang_dir + "/HL.fst")
sp = spm.SentencePieceProcessor()
sp.load(lang_dir + "/bpe.model")
i = lexicon.word2id["HELLOA"]
k = lexicon.word2id["WORLD"]
print(i, k)
s = f"""
0 1 {i} {i}
1 2 {k} {k}
2
"""
fst = kaldifst.compile(
s=s,
acceptor=False,
)
L = make_lexicon_fst_no_silence(lexicon, attach_symbol_table=False)
kaldifst.arcsort(L, sort_type="olabel")
with open("L.fst.txt", "w") as f:
print(L, file=f)
fst = kaldifst.compose(L, fst)
print(fst)
fst_dot = kaldifst.draw(fst, acceptor=False, portrait=True)
source = graphviz.Source(fst_dot)
source.render(outfile="a.pdf")
print(sp.encode(["HELLOA", "WORLD"]))
def main():
test_yesno()
test_librispeech()
if __name__ == "__main__":
main()


@ -0,0 +1,43 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
from pathlib import Path
import graphviz
import kaldifst
from prepare_lang import Lexicon, make_lexicon_fst_with_silence
def test_yesno():
lang_dir = "/Users/fangjun/open-source/icefall/egs/yesno/ASR/data/lang_phone"
if not Path(lang_dir).is_dir():
print(f"{lang_dir} does not exist! Skip testing")
return
lexicon = Lexicon(lang_dir)
L = make_lexicon_fst_with_silence(lexicon)
isym = kaldifst.SymbolTable()
for i, token in lexicon.id2token.items():
isym.add_symbol(symbol=token, key=i)
osym = kaldifst.SymbolTable()
for i, word in lexicon.id2word.items():
osym.add_symbol(symbol=word, key=i)
L.input_symbols = isym
L.output_symbols = osym
fst_dot = kaldifst.draw(L, acceptor=False, portrait=True)
source = graphviz.Source(fst_dot)
source.render(outfile="L_yesno.pdf")
# See the link below to visualize the above PDF
# https://t.ly/jMfXW
def main():
test_yesno()
if __name__ == "__main__":
main()

icefall/ctc/topo.py

@ -0,0 +1,137 @@
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
import kaldifst
# The name contains `standard` because non-standard
# topologies also exist.
def build_standard_ctc_topo(max_token_id: int) -> kaldifst.StdVectorFst:
"""Build a standard CTC topology.
Args:
max_token_id:
Maximum valid token ID. We assume token IDs are contiguous
and start from 0. In other words, the vocabulary size is
``max_token_id + 1``. We assume the ID of the blank symbol is 0.
"""
# Token ID starts from 0 and there are as many states as the
# number of tokens.
#
# Note that epsilon is not a token and the token with ID 0 in tokens.txt
# is not an epsilon. It means input label 0 of the resulting FST does
# not represent an epsilon.
#
# You can use the function `add_one()` to modify the input/output labels
# of the resulting FST
num_states = max_token_id + 1
# Step 1: Create as many states as the number of tokens.
# Each state is a final state
fst = kaldifst.StdVectorFst()
for i in range(num_states):
s = fst.add_state()
fst.set_final(state=s, weight=0)
# Step 2: Set state 0 as the start state.
# We assume the ID of the blank symbol is 0.
fst.start = 0
# Step 3: Build a fully connected graph.
for i in range(num_states):
for k in range(num_states):
fst.add_arc(
state=i,
arc=kaldifst.StdArc(
ilabel=k,
olabel=k if i != k else 0, # if i==k, it is a self loop
weight=0,
nextstate=k,
),
)
# Please see ./test_ctc_topo.py if you want to know what the resulting
# FST looks like
return fst
def add_one(
fst: kaldifst.StdVectorFst,
treat_ilabel_zero_specially: bool,
update_olabel: bool,
) -> None:
"""Modify the input and output labels of the given FST in-place.
Args:
fst:
The FST to be modified. It is changed in-place.
treat_ilabel_zero_specially:
If True, then every non-zero input label is increased by one and the
zero input label is not changed.
If False, then every input label is increased by one.
update_olabel:
If False, the output label is not changed.
If True, then every non-zero output label is increased by one.
In either case, an output label of 0 is left unchanged.
"""
for state in kaldifst.StateIterator(fst):
for arc in kaldifst.ArcIterator(fst, state):
# If treat_ilabel_zero_specially is False, we always change it
# Otherwise, we only change non-zero input labels
if treat_ilabel_zero_specially is False or arc.ilabel != 0:
arc.ilabel += 1
if update_olabel and arc.olabel != 0:
arc.olabel += 1
if fst.input_symbols is not None:
input_symbols = kaldifst.SymbolTable()
input_symbols.add_symbol(symbol="<eps>", key=0)
for i in range(0, fst.input_symbols.num_symbols()):
s = fst.input_symbols.find(i)
input_symbols.add_symbol(symbol=s, key=i + 1)
fst.input_symbols = input_symbols
if update_olabel and fst.output_symbols is not None:
output_symbols = kaldifst.SymbolTable()
output_symbols.add_symbol(symbol="<eps>", key=0)
for i in range(0, fst.output_symbols.num_symbols()):
s = fst.output_symbols.find(i)
output_symbols.add_symbol(symbol=s, key=i + 1)
fst.output_symbols = output_symbols
def add_disambig_self_loops(fst: kaldifst.StdVectorFst, start: int, end: int):
"""Add self-loops to each state.
For each disambig symbol, we add a self-loop whose input and output
labels are both the ID of that disambig symbol.
Args:
fst:
It is changed in-place.
start:
The ID of #0
end:
The ID of the last disambig symbol. For instance if there are 3
disambig symbols ``#0``, ``#1``, and ``#2``, then ``end`` is the ID
of ``#2``.
"""
for state in kaldifst.StateIterator(fst):
for i in range(start, end + 1):
fst.add_arc(
state=state,
arc=kaldifst.StdArc(
ilabel=i,
olabel=i,
weight=0,
nextstate=state,
),
)
if fst.output_symbols:
for i in range(start, end + 1):
fst.output_symbols.add_symbol(symbol=f"#{i-start}", key=i)
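
As a quick illustration, the sketch below builds the standard CTC topology for a three-token vocabulary (blank = 0) and lists its arcs; self-loops (i == k) output epsilon, while cross-arcs output the token being entered:

```python
# Build and inspect the standard CTC topology for tokens {0, 1, 2}.
import kaldifst

from icefall.ctc import add_one, build_standard_ctc_topo

H = build_standard_ctc_topo(max_token_id=2)
for state in kaldifst.StateIterator(H):
    for arc in kaldifst.ArcIterator(H, state):
        print(state, "->", arc.nextstate, "ilabel:", arc.ilabel, "olabel:", arc.olabel)

# Shift all labels up by one so that 0 can serve as epsilon before
# composing with the lexicon FST (see prepare_lang_fst.py).
add_one(H, treat_ilabel_zero_specially=False, update_olabel=True)
```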


@ -27,3 +27,4 @@ onnx
onnxmltools
onnxruntime
kaldifst
kaldi-hmm-gmm


@ -1,6 +1,7 @@
kaldifst
kaldilm
kaldialign
kaldi-hmm-gmm
sentencepiece>=0.1.96
tensorboard
typeguard