From 2318c3fbd011b14ceffe8b3a8663057708afeea0 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 26 Sep 2023 16:36:19 +0800 Subject: [PATCH] Support CTC decoding on CPU using OpenFst and kaldi decoders. (#1244) --- .flake8 | 1 + .../scripts/run-pre-trained-conformer-ctc.sh | 43 +++ .../run-pretrained-conformer-ctc.yml | 2 +- .github/workflows/run-yesno-recipe.yml | 37 ++ .gitignore | 2 + docs/source/model-export/export-ncnn.rst | 2 + .../jit_pretrained_decode_with_H.py | 235 ++++++++++++ .../jit_pretrained_decode_with_HL.py | 232 ++++++++++++ egs/librispeech/ASR/local/prepare_lang_fst.py | 127 +++++++ .../lstm_transducer_stateless/test_model.py | 3 +- egs/librispeech/ASR/prepare.sh | 4 + egs/yesno/ASR/local/prepare_lang_fst.py | 1 + egs/yesno/ASR/prepare.sh | 1 + egs/yesno/ASR/tdnn/jit_pretrained.py | 1 - .../ASR/tdnn/jit_pretrained_decode_with_H.py | 208 +++++++++++ .../ASR/tdnn/jit_pretrained_decode_with_HL.py | 207 +++++++++++ icefall/ctc/.gitignore | 2 + icefall/ctc/README.md | 17 + icefall/ctc/__init__.py | 6 + icefall/ctc/prepare_lang.py | 334 ++++++++++++++++++ icefall/ctc/test_ctc_topo.py | 140 ++++++++ icefall/ctc/test_prepare_lang.py | 43 +++ icefall/ctc/topo.py | 137 +++++++ requirements-ci.txt | 1 + requirements.txt | 1 + 25 files changed, 1783 insertions(+), 4 deletions(-) create mode 100755 egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py create mode 100755 egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py create mode 100755 egs/librispeech/ASR/local/prepare_lang_fst.py create mode 120000 egs/yesno/ASR/local/prepare_lang_fst.py create mode 100755 egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py create mode 100755 egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py create mode 100644 icefall/ctc/.gitignore create mode 100644 icefall/ctc/README.md create mode 100644 icefall/ctc/__init__.py create mode 100644 icefall/ctc/prepare_lang.py create mode 100755 icefall/ctc/test_ctc_topo.py create mode 100755 icefall/ctc/test_prepare_lang.py create mode 100644 icefall/ctc/topo.py diff --git a/.flake8 b/.flake8 index 1c0c2cdbb..410cb5482 100644 --- a/.flake8 +++ b/.flake8 @@ -24,6 +24,7 @@ exclude = **/data/**, icefall/shared/make_kn_lm.py, icefall/__init__.py + icefall/ctc/__init__.py ignore = # E203 white space before ":" diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh index a4959aa01..19cbd96fc 100755 --- a/.github/scripts/run-pre-trained-conformer-ctc.sh +++ b/.github/scripts/run-pre-trained-conformer-ctc.sh @@ -44,3 +44,46 @@ log "HLG decoding" $repo/test_wavs/1089-134686-0001.flac \ $repo/test_wavs/1221-135766-0001.flac \ $repo/test_wavs/1221-135766-0002.flac + +log "CTC decoding on CPU with kaldi decoders using OpenFst" + +log "Exporting model with torchscript" + +pushd $repo/exp +ln -s pretrained.pt epoch-99.pt +popd + +./conformer_ctc/export.py \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --jit 1 + +ls -lh $repo/exp + + +log "Generating H.fst, HL.fst" + +./local/prepare_lang_fst.py --lang-dir $repo/data/lang_bpe_500 +ls -lh $repo/data/lang_bpe_500 + +log "Decoding with H on CPU with OpenFst" + +./conformer_ctc/jit_pretrained_decode_with_H.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --H $repo/data/lang_bpe_500/H.fst \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.flac \ + $repo/test_wavs/1221-135766-0001.flac \ + $repo/test_wavs/1221-135766-0002.flac + +log "Decoding with HL on CPU 
with OpenFst" + +./conformer_ctc/jit_pretrained_decode_with_HL.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --HL $repo/data/lang_bpe_500/HL.fst \ + --words $repo/data/lang_bpe_500/words.txt \ + $repo/test_wavs/1089-134686-0001.flac \ + $repo/test_wavs/1221-135766-0001.flac \ + $repo/test_wavs/1221-135766-0002.flac diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-conformer-ctc.yml index 6151a5a14..e268d840d 100644 --- a/.github/workflows/run-pretrained-conformer-ctc.yml +++ b/.github/workflows/run-pretrained-conformer-ctc.yml @@ -29,7 +29,7 @@ concurrency: jobs: run_pre_trained_conformer_ctc: - if: github.event.label.name == 'ready' || github.event_name == 'push' + if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'ctc' runs-on: ${{ matrix.os }} strategy: matrix: diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index 57f15fe87..400595749 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -140,9 +140,46 @@ jobs: download/waves_yesno/0_0_0_1_0_0_0_1.wav \ download/waves_yesno/0_0_1_0_0_0_1_0.wav + - name: Test decoding with H + shell: bash + working-directory: ${{github.workspace}} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + echo $PYTHONPATH + + cd egs/yesno/ASR + python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 + + python3 ./tdnn/jit_pretrained_decode_with_H.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --H ./data/lang_phone/H.fst \ + --tokens ./data/lang_phone/tokens.txt \ + ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ + ./download/waves_yesno/0_0_1_0_0_1_1_1.wav + + - name: Test decoding with HL + shell: bash + working-directory: ${{github.workspace}} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + echo $PYTHONPATH + + cd egs/yesno/ASR + python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 + + python3 ./tdnn/jit_pretrained_decode_with_HL.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --HL ./data/lang_phone/HL.fst \ + --words ./data/lang_phone/words.txt \ + ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ + ./download/waves_yesno/0_0_1_0_0_1_1_1.wav + - name: Show generated files shell: bash working-directory: ${{github.workspace}} run: | cd egs/yesno/ASR ls -lh tdnn/exp + ls -lh data/lang_phone diff --git a/.gitignore b/.gitignore index 8af05d884..fa18ca83c 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,5 @@ node_modules *.param *.bin .DS_Store +*.fst +*.arpa diff --git a/docs/source/model-export/export-ncnn.rst b/docs/source/model-export/export-ncnn.rst index 9eb5f85d2..634fb1e59 100644 --- a/docs/source/model-export/export-ncnn.rst +++ b/docs/source/model-export/export-ncnn.rst @@ -1,3 +1,5 @@ +.. _icefall_export_to_ncnn: + Export to ncnn ============== diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py new file mode 100755 index 000000000..b52c7cfed --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This file shows how to use a torchscript model for decoding with H +on CPU using OpenFST and decoders from kaldi. 
+ +Usage: + + ./conformer_ctc/jit_pretrained_decode_with_H.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --H ./data/lang_bpe_500/H.fst \ + --tokens ./data/lang_bpe_500/tokens.txt \ + ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \ + ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac + +Note that to generate ./conformer_ctc/exp/cpu_jit.pt, +you can use ./export.py --jit 1 +""" + +import argparse +import logging +import math +from typing import Dict, List + +import kaldi_hmm_gmm +import kaldifeat +import kaldifst +import torch +import torchaudio +from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./conformer_ctc/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to tokens.txt", + ) + + parser.add_argument("--H", type=str, required=True, help="Path to H.fst") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +def read_tokens(tokens_txt: str) -> Dict[int, str]: + id2token = dict() + with open(tokens_txt, encoding="utf-8") as f: + for line in f: + token, idx = line.strip().split() + id2token[int(idx)] = token + + return id2token + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + nnet_output: torch.Tensor, + H: kaldifst, + id2token: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + nnet_output: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + H: + The H graph. + id2token: + A map mapping token ID to token string. + Returns: + Return a list of decoded tokens. 
+ """ + logging.info(f"{filename}, {nnet_output.shape}") + decodable = DecodableCtc(nnet_output.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(H, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + print(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + print(f"failed to get linear symbol sequence for {filename}") + return [""] + + # tokens are incremented during graph construction + # so they need to be decremented + hyps = [id2token[i - 1] for i in osymbols_out] + # hyps = "".join(hyps).split("▁") + hyps = "".join(hyps).split("\u2581") # unicode codepoint of ▁ + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading H from {args.H}") + H = kaldifst.StdVectorFst.read(args.H) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.shape[0] for f in features] + feature_lengths = torch.tensor(feature_lengths) + + supervisions = dict() + supervisions["sequence_idx"] = torch.arange(len(features)) + supervisions["start_frame"] = torch.zeros(len(features)) + supervisions["num_frames"] = feature_lengths + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + nnet_output, _, _ = model(features, supervisions) + feature_lengths = ((feature_lengths - 1) // 2 - 1) // 2 + + id2token = read_tokens(args.tokens) + + hyps = [] + for i in range(nnet_output.shape[0]): + hyp = decode( + filename=args.sound_files[i], + nnet_output=nnet_output[i, : feature_lengths[i]], + H=H, + id2token=id2token, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py new file mode 100755 index 000000000..f0326ccdf --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This file shows how to use a torchscript model for decoding with H +on CPU using OpenFST and decoders from kaldi. 
+ +Usage: + + ./conformer_ctc/jit_pretrained_decode_with_H.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --HL ./data/lang_bpe_500/HL.fst \ + --words ./data/lang_bpe_500/words.txt \ + ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \ + ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac + +Note that to generate ./conformer_ctc/exp/cpu_jit.pt, +you can use ./export.py --jit 1 +""" + +import argparse +import logging +import math +from typing import Dict, List + +import kaldi_hmm_gmm +import kaldifeat +import kaldifst +import torch +import torchaudio +from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./conformer_ctc/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--words", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument("--HL", type=str, required=True, help="Path to HL.fst") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +def read_words(words_txt: str) -> Dict[int, str]: + id2word = dict() + with open(words_txt, encoding="utf-8") as f: + for line in f: + word, idx = line.strip().split() + id2word[int(idx)] = word + + return id2word + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + nnet_output: torch.Tensor, + HL: kaldifst, + id2word: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + nnet_output: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + HL: + The HL graph. + word2token: + A map mapping token ID to word string. + Returns: + Return a list of decoded words. 
+ """ + logging.info(f"{filename}, {nnet_output.shape}") + decodable = DecodableCtc(nnet_output.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HL, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + print(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + print(f"failed to get linear symbol sequence for {filename}") + return [""] + + # are shifted by 1 during graph construction + hyps = [id2word[i] for i in osymbols_out] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading HL from {args.HL}") + HL = kaldifst.StdVectorFst.read(args.HL) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.shape[0] for f in features] + feature_lengths = torch.tensor(feature_lengths) + + supervisions = dict() + supervisions["sequence_idx"] = torch.arange(len(features)) + supervisions["start_frame"] = torch.zeros(len(features)) + supervisions["num_frames"] = feature_lengths + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + nnet_output, _, _ = model(features, supervisions) + feature_lengths = ((feature_lengths - 1) // 2 - 1) // 2 + + id2word = read_words(args.words) + + hyps = [] + for i in range(nnet_output.shape[0]): + hyp = decode( + filename=args.sound_files[i], + nnet_output=nnet_output[i, : feature_lengths[i]], + HL=HL, + id2word=id2word, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/local/prepare_lang_fst.py b/egs/librispeech/ASR/local/prepare_lang_fst.py new file mode 100755 index 000000000..e8401123f --- /dev/null +++ b/egs/librispeech/ASR/local/prepare_lang_fst.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang) + +""" +This script takes as input lang_dir containing lexicon_disambig.txt, +tokens.txt, and words.txt and generates the following files: + + - H.fst + - HL.fst + +Note that saved files are in OpenFst binary format. 
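+
+In both H.fst and HL.fst, the input labels are shifted by one so that input
+label 0 can be reserved for epsilon: input label i corresponds to the token
+with ID i - 1 in tokens.txt.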
+ +Usage: + +./local/prepare_lang_fst.py \ + --lang-dir ./data/lang_phone \ + --has-silence 1 + +Or + +./local/prepare_lang_fst.py \ + --lang-dir ./data/lang_bpe_500 +""" + +import argparse +import logging +from pathlib import Path + +import kaldifst + +from icefall.ctc import ( + Lexicon, + add_disambig_self_loops, + add_one, + build_standard_ctc_topo, + make_lexicon_fst_no_silence, + make_lexicon_fst_with_silence, +) +from icefall.utils import str2bool + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + """, + ) + + parser.add_argument( + "--has-silence", + type=str2bool, + default=False, + help="True if the lexicon has silence.", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = args.lang_dir + + lexicon = Lexicon(lang_dir) + + logging.info("Building standard CTC topology") + max_token_id = max(lexicon.tokens) + H = build_standard_ctc_topo(max_token_id=max_token_id) + + # We need to add one to all tokens since we want to use ID 0 + # for epsilon + add_one(H, treat_ilabel_zero_specially=False, update_olabel=True) + H.write(f"{lang_dir}/H.fst") + + logging.info("Building L") + # Now for HL + + if args.has_silence: + L = make_lexicon_fst_with_silence(lexicon, attach_symbol_table=False) + else: + L = make_lexicon_fst_no_silence(lexicon, attach_symbol_table=False) + + if args.has_silence: + # We also need to change the input labels of L + add_one(L, treat_ilabel_zero_specially=True, update_olabel=False) + else: + add_one(L, treat_ilabel_zero_specially=False, update_olabel=False) + + # Invoke add_disambig_self_loops() so that it eats the disambig symbols + # from L after composition + add_disambig_self_loops( + H, + start=lexicon.token2id["#0"] + 1, + end=lexicon.max_disambig_id + 1, + ) + with open("H_1.fst.txt", "w") as f: + print(H, file=f) + + kaldifst.arcsort(H, sort_type="olabel") + kaldifst.arcsort(L, sort_type="ilabel") + + logging.info("Building HL") + HL = kaldifst.compose(H, L) + kaldifst.determinize_star(HL) + + disambig0 = lexicon.token2id["#0"] + 1 + max_disambig = lexicon.max_disambig_id + 1 + for state in kaldifst.StateIterator(HL): + for arc in kaldifst.ArcIterator(HL, state): + # If treat_ilabel_zero_specially is False, we always change it + # Otherwise, we only change non-zero input labels + if disambig0 <= arc.ilabel <= max_disambig: + arc.ilabel = 0 + + # Note: We are not composing L with G, so there is no need to add + # self-loops to L to handle #0 + + HL.write(f"{lang_dir}/HL.fst") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/test_model.py b/egs/librispeech/ASR/lstm_transducer_stateless/test_model.py index 03dfe1997..91ef53e24 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/test_model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/test_model.py @@ -57,8 +57,7 @@ def test_model(): convert_scaled_to_non_scaled(model, inplace=True) - if not os.path.exists(params.exp_dir): - os.path.mkdir(params.exp_dir) + params.exp_dir.mkdir(exist_ok=True) encoder_filename = params.exp_dir / "encoder_jit_trace.pt" export_encoder_model_jit_trace(model.encoder, encoder_filename) diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 8ce1eb478..fca2c6cc4 100755 --- a/egs/librispeech/ASR/prepare.sh +++ 
b/egs/librispeech/ASR/prepare.sh @@ -242,6 +242,10 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then $lang_dir/L_disambig.pt \ $lang_dir/L_disambig.fst fi + + if [ ! -f $lang_dir/HL.fst ]; then + ./local/prepare_lang_fst.py --lang-dir $lang_dir + fi done fi diff --git a/egs/yesno/ASR/local/prepare_lang_fst.py b/egs/yesno/ASR/local/prepare_lang_fst.py new file mode 120000 index 000000000..c5787c534 --- /dev/null +++ b/egs/yesno/ASR/local/prepare_lang_fst.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_fst.py \ No newline at end of file diff --git a/egs/yesno/ASR/prepare.sh b/egs/yesno/ASR/prepare.sh index d4ef8d601..41db0cf7c 100755 --- a/egs/yesno/ASR/prepare.sh +++ b/egs/yesno/ASR/prepare.sh @@ -60,6 +60,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then ) > $lang_dir/lexicon.txt ./local/prepare_lang.py + ./local/prepare_lang_fst.py --lang-dir ./data/lang_phone --has-silence 1 fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then diff --git a/egs/yesno/ASR/tdnn/jit_pretrained.py b/egs/yesno/ASR/tdnn/jit_pretrained.py index 84390fca5..7581ecb83 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained.py +++ b/egs/yesno/ASR/tdnn/jit_pretrained.py @@ -156,7 +156,6 @@ def main(): features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - # Note: We don't use key padding mask for attention during decoding nnet_output = model(features) batch_size = nnet_output.shape[0] diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py new file mode 100755 index 000000000..209ab477a --- /dev/null +++ b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This file shows how to use a torchscript model for decoding with H +on CPU using OpenFST and decoders from kaldi. + +Usage: + + ./tdnn/jit_pretrained_decode_with_H.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --H ./data/lang_phone/H.fst \ + --tokens ./data/lang_phone/tokens.txt \ + ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ + ./download/waves_yesno/0_0_1_0_0_1_1_1.wav + +Note that to generate ./tdnn/exp/cpu_jit.pt, +you can use ./export.py --jit 1 +""" + +import argparse +import logging +import math +from typing import Dict, List + +import kaldifeat +import kaldifst +import torch +import torchaudio +from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./tdnn/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to tokens.txt", + ) + + parser.add_argument("--H", type=str, required=True, help="Path to H.fst") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
", + ) + + return parser + + +def read_tokens(tokens_txt: str) -> Dict[int, str]: + id2token = dict() + with open(tokens_txt, encoding="utf-8") as f: + for line in f: + token, idx = line.strip().split() + id2token[int(idx)] = token + + return id2token + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + nnet_output: torch.Tensor, + H: kaldifst, + id2token: Dict[int, str], +) -> List[str]: + decodable = DecodableCtc(nnet_output) + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(H, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + print(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + print(f"failed to get linear symbol sequence for {filename}") + return [""] + + # are shifted by 1 during graph construction + hyps = [id2token[i - 1] for i in osymbols_out if id2token[i - 1] != "SIL"] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading H from {args.H}") + H = kaldifst.StdVectorFst.read(args.H) + + sample_rate = 8000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 23 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + nnet_output = model(features) + + id2token = read_tokens(args.tokens) + + hyps = [] + for i in range(nnet_output.shape[0]): + hyp = decode( + filename=args.sound_files[0], + nnet_output=nnet_output[i], + H=H, + id2token=id2token, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py new file mode 100755 index 000000000..74864e17d --- /dev/null +++ 
b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This file shows how to use a torchscript model for decoding with HL +on CPU using OpenFST and decoders from kaldi. + +Usage: + + ./tdnn/jit_pretrained_decode_with_HL.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --HL ./data/lang_phone/HL.fst \ + --words ./data/lang_phone/words.txt \ + ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ + ./download/waves_yesno/0_0_1_0_0_1_1_1.wav + +Note that to generate ./tdnn/exp/cpu_jit.pt, +you can use ./export.py --jit 1 +""" + +import argparse +import logging +import math +from typing import Dict, List + +import kaldifeat +import kaldifst +import torch +import torchaudio +from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./tdnn/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--words", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument("--HL", type=str, required=True, help="Path to HL.fst") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +def read_words(words_txt: str) -> Dict[int, str]: + id2word = dict() + with open(words_txt, encoding="utf-8") as f: + for line in f: + word, idx = line.strip().split() + id2word[int(idx)] = word + + return id2word + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. 
+ """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + nnet_output: torch.Tensor, + HL: kaldifst, + id2word: Dict[int, str], +) -> List[str]: + decodable = DecodableCtc(nnet_output) + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HL, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + print(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + print(f"failed to get linear symbol sequence for {filename}") + return [""] + + hyps = [id2word[i] for i in osymbols_out if id2word[i] != ""] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading HL from {args.HL}") + HL = kaldifst.StdVectorFst.read(args.HL) + + sample_rate = 8000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 23 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + nnet_output = model(features) + + id2word = read_words(args.words) + + hyps = [] + for i in range(nnet_output.shape[0]): + hyp = decode( + filename=args.sound_files[0], + nnet_output=nnet_output[i], + HL=HL, + id2word=id2word, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/ctc/.gitignore b/icefall/ctc/.gitignore new file mode 100644 index 000000000..8154cb57f --- /dev/null +++ b/icefall/ctc/.gitignore @@ -0,0 +1,2 @@ +*.pdf +*.gv diff --git a/icefall/ctc/README.md b/icefall/ctc/README.md new file mode 100644 index 000000000..07b0ff8cd --- /dev/null +++ b/icefall/ctc/README.md @@ -0,0 +1,17 @@ +# Introduction + +This folder uses [kaldifst][kaldifst] for graph construction +and decoders from [kaldi-hmm-gmm][kaldi-hmm-gmm] for CTC decoding. + +It supports only `CPU`. + +You can use + +```bash +pip install kaldifst kaldi-hmm-gmm +``` +to install the dependencies. 
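+
+Below is a minimal sketch of how graph construction and decoding fit
+together. It assumes `H.fst` and `tokens.txt` already exist (they can be
+generated with `./local/prepare_lang_fst.py`) and that `nnet_output` is a
+2-D log-softmax tensor produced by an acoustic model; see
+`egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py` for a complete,
+runnable example.
+
+```python
+import kaldifst
+from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions
+
+# H.fst is the CTC topology written by ./local/prepare_lang_fst.py
+H = kaldifst.StdVectorFst.read("data/lang_phone/H.fst")
+
+# Map token IDs to token symbols using tokens.txt
+id2token = {}
+with open("data/lang_phone/tokens.txt", encoding="utf-8") as f:
+    for line in f:
+        token, idx = line.split()
+        id2token[int(idx)] = token
+
+# nnet_output: a 2-D (num_frames, vocab_size) tensor containing
+# log-softmax output of the acoustic model (assumed to be available)
+decodable = DecodableCtc(nnet_output.cpu())
+
+decoder = FasterDecoder(H, FasterDecoderOptions(max_active=3000))
+decoder.decode(decodable)
+
+if decoder.reached_final():
+    ok, best_path = decoder.get_best_path()
+    ok, isyms, osyms, weight = kaldifst.get_linear_symbol_sequence(best_path)
+    # Labels in H are token IDs + 1 (0 is reserved for epsilon),
+    # so subtract one before the lookup
+    hyps = [id2token[i - 1] for i in osyms]
+```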
+ +[kaldi-hmm-gmm]: https://github.com/csukuangfj/kaldi-hmm-gmm +[kaldifst]: https://github.com/k2-fsa/kaldifst +[k2]: https://github.com/k2-fsa/k2 diff --git a/icefall/ctc/__init__.py b/icefall/ctc/__init__.py new file mode 100644 index 000000000..b546b31af --- /dev/null +++ b/icefall/ctc/__init__.py @@ -0,0 +1,6 @@ +from .prepare_lang import ( + Lexicon, + make_lexicon_fst_no_silence, + make_lexicon_fst_with_silence, +) +from .topo import add_disambig_self_loops, add_one, build_standard_ctc_topo diff --git a/icefall/ctc/prepare_lang.py b/icefall/ctc/prepare_lang.py new file mode 100644 index 000000000..4801b1beb --- /dev/null +++ b/icefall/ctc/prepare_lang.py @@ -0,0 +1,334 @@ +# Copyright 2023 Xiaomi Corp. (author: Fangjun Kuang) + +""" +The lang_dir should contain the following files: + - "lexicon_disambig.txt" + - "tokens.txt" + - "words.txt" +""" + +import math +from collections import defaultdict +from pathlib import Path +from typing import List, Tuple + +import kaldifst +import re + + +class Lexicon: + """Once constructed it is immutable""" + + def __init__( + self, + lang_dir: str, + disambig_pattern: str = re.compile(r"^#\d+$"), + ): + """ + Args: + lang_dir: + The path to the lang directory. We expect that it contains the + following files: + - lexicon_disambig.txt + - tokens.txt + - words.txt + + The format of the above files is described below. + + (1) lexicon_disambig.txt + + Each line in the lexicon_disambig.txt has the following format: + + word token1 token2 ... tokenN + + That is, the first field is the word, the remaining fields are + pronunciations of this word. Fields are separated by space(s). + + (2) tokens.txt + + Each line in tokens.txt has two fields separated by space(s): + + token ID + + The first field is the token symbol and the second filed is the + integer ID of the token. + + (3) words.txt + + Each line in words.txt has two fields separated by space(s): + + word ID + + The first field is the word symbol and the second filed is the + integer ID of the word. + disambig_pattern: + It contains the pattern for disambiguation symbols. 
+ """ + lang_dir = Path(lang_dir) + + lexicon_txt = lang_dir / "lexicon_disambig.txt" + tokens_txt = lang_dir / "tokens.txt" + words_txt = lang_dir / "words.txt" + + assert lexicon_txt.is_file(), lexicon_txt + assert tokens_txt.is_file(), tokens_txt + assert words_txt.is_file(), words_txt + + self._read_lexicon(lexicon_txt) + self._read_tokens(tokens_txt) + self._read_words(words_txt) + + self.disambig_pattern = disambig_pattern + + max_disambig_id = -1 + for s, i in self.token2id.items(): + if self.disambig_pattern.match(s) and i > max_disambig_id: + max_disambig_id = i + + self.max_disambig_id = max_disambig_id + + def _read_lexicon(self, lexicon_txt: str): + word2phones = defaultdict(list) + with open(lexicon_txt, encoding="utf-8") as f: + for line in f: + word_phones = line.strip().split() + assert len(word_phones) >= 2, (word_phones, line) + word = word_phones[0] + phones: str = " ".join(word_phones[1:]) + word2phones[word].append(phones) + # We use a list here since a word may have multiple + # pronunciations + + self.word2phones = word2phones + + def _read_tokens(self, tokens_txt): + token2id = dict() + id2token = dict() + with open(tokens_txt, encoding="utf-8") as f: + for line in f: + token_id = line.strip().split() + assert len(token_id) == 2, token_id + + token = token_id[0] + idx = int(token_id[1]) + + assert token not in token2id, f"Duplicate token {line}" + assert idx not in id2token, f"Duplicate ID {line}" + + token2id[token] = idx + id2token[idx] = token + self.token2id = token2id + self.id2token = id2token + + def _read_words(self, words_txt): + word2id = dict() + id2word = dict() + with open(words_txt, encoding="utf-8") as f: + for line in f: + word_id = line.strip().split() + assert len(word_id) == 2, word_id + + word = word_id[0] + idx = int(word_id[1]) + + assert word not in word2id, f"Duplicate token {line}" + assert idx not in id2word, f"Duplicate ID {line}" + + word2id[word] = idx + id2word[idx] = word + + self.word2id = word2id + self.id2word = id2word + + def __iter__(self) -> Tuple[str, List[str]]: + for word, phones_list in self.word2phones.items(): + for phones in phones_list: + yield word, phones + + def __str__(self): + return str(self.word2phones) + + @property + def tokens(self) -> List[int]: + """Return a list of token IDs excluding those from + disambiguation symbols. + + Caution: + 0 is not a token ID so it is excluded from the return value. 
+ """ + ans = [] + for s in self.token2id: + if not self.disambig_pattern.match(s): + ans.append(self.token2id[s]) + if 0 in ans: + ans.remove(0) + ans.sort() + return ans + + +# See also +# http://vpanayotov.blogspot.com/2012/06/kaldi-decoding-graph-construction.html +def make_lexicon_fst_with_silence( + lexicon: Lexicon, + sil_prob: float = 0.5, + sil_phone: str = "SIL", + attach_symbol_table: bool = True, +) -> kaldifst.StdVectorFst: + phone2id = lexicon.token2id + word2id = lexicon.word2id + + assert sil_phone in phone2id + + assert sil_phone in phone2id, sil_phone + + sil_cost = -1 * math.log(sil_prob) + no_sil_cost = -1 * math.log(1.0 - sil_prob) + + fst = kaldifst.StdVectorFst() + + start_state = fst.add_state() + loop_state = fst.add_state() + sil_state = fst.add_state() + + fst.start = start_state + fst.set_final(state=loop_state, weight=0) + + fst.add_arc( + state=start_state, + arc=kaldifst.StdArc( + ilabel=0, + olabel=0, + weight=no_sil_cost, + nextstate=loop_state, + ), + ) + + fst.add_arc( + state=start_state, + arc=kaldifst.StdArc( + ilabel=0, + olabel=0, + weight=sil_cost, + nextstate=sil_state, + ), + ) + + fst.add_arc( + state=sil_state, + arc=kaldifst.StdArc( + ilabel=phone2id[sil_phone], + olabel=0, + weight=0, + nextstate=loop_state, + ), + ) + + for word, phones in lexicon: + phoneseq = phones.split() + pron_cost = 0 + cur_state = loop_state + + for i in range(len(phoneseq) - 1): + next_state = fst.add_state() + fst.add_arc( + state=cur_state, + arc=kaldifst.StdArc( + ilabel=phone2id[phoneseq[i]], + olabel=word2id[word] if i == 0 else 0, + weight=pron_cost if i == 0 else 0, + nextstate=next_state, + ), + ) + cur_state = next_state + + i = len(phoneseq) - 1 # note: i == -1 if phoneseq is empty. + + fst.add_arc( + state=cur_state, + arc=kaldifst.StdArc( + ilabel=phone2id[phoneseq[i]] if i >= 0 else 0, + olabel=word2id[word] if i <= 0 else 0, + weight=no_sil_cost + (pron_cost if i <= 0 else 0), + nextstate=loop_state, + ), + ) + + fst.add_arc( + state=cur_state, + arc=kaldifst.StdArc( + ilabel=phone2id[phoneseq[i]] if i >= 0 else 0, + olabel=word2id[word] if i <= 0 else 0, + weight=sil_cost + (pron_cost if i <= 0 else 0), + nextstate=sil_state, + ), + ) + + if attach_symbol_table: + isym = kaldifst.SymbolTable() + for p, i in phone2id.items(): + isym.add_symbol(symbol=p, key=i) + fst.input_symbols = isym + + osym = kaldifst.SymbolTable() + for w, i in word2id.items(): + osym.add_symbol(symbol=w, key=i) + fst.output_symbols = osym + + return fst + + +def make_lexicon_fst_no_silence( + lexicon: Lexicon, + attach_symbol_table: bool = True, +) -> kaldifst.StdVectorFst: + phone2id = lexicon.token2id + word2id = lexicon.word2id + + fst = kaldifst.StdVectorFst() + + start_state = fst.add_state() + fst.start = start_state + fst.set_final(state=start_state, weight=0) + + for word, phones in lexicon: + phoneseq = phones.split() + pron_cost = 0 + cur_state = start_state + + for i in range(len(phoneseq) - 1): + next_state = fst.add_state() + fst.add_arc( + state=cur_state, + arc=kaldifst.StdArc( + ilabel=phone2id[phoneseq[i]], + olabel=word2id[word] if i == 0 else 0, + weight=pron_cost if i == 0 else 0, + nextstate=next_state, + ), + ) + cur_state = next_state + + i = len(phoneseq) - 1 # note: i == -1 if phoneseq is empty. 
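+
+        # The final arc below goes back to the start state: it consumes the
+        # last phone and emits the word only when the pronunciation has a
+        # single phone (i == 0) or is empty (i == -1); for longer
+        # pronunciations the word was already emitted on the first arc above.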
+ + fst.add_arc( + state=cur_state, + arc=kaldifst.StdArc( + ilabel=phone2id[phoneseq[i]] if i >= 0 else 0, + olabel=word2id[word] if i <= 0 else 0, + weight=pron_cost if i <= 0 else 0, + nextstate=start_state, + ), + ) + + if attach_symbol_table: + isym = kaldifst.SymbolTable() + for p, i in phone2id.items(): + isym.add_symbol(symbol=p, key=i) + fst.input_symbols = isym + + osym = kaldifst.SymbolTable() + for w, i in word2id.items(): + osym.add_symbol(symbol=w, key=i) + fst.output_symbols = osym + + return fst diff --git a/icefall/ctc/test_ctc_topo.py b/icefall/ctc/test_ctc_topo.py new file mode 100755 index 000000000..4d4667209 --- /dev/null +++ b/icefall/ctc/test_ctc_topo.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +from pathlib import Path + +import graphviz +import kaldifst +import sentencepiece as spm +from prepare_lang import ( + Lexicon, + make_lexicon_fst_no_silence, + make_lexicon_fst_with_silence, +) +from topo import add_disambig_self_loops, add_one, build_standard_ctc_topo + + +def test_yesno(): + lang_dir = "/Users/fangjun/open-source/icefall/egs/yesno/ASR/data/lang_phone" + if not Path(lang_dir).is_dir(): + print(f"{lang_dir} does not exist! Skip testing") + return + lexicon = Lexicon(lang_dir) + max_token_id = max(lexicon.tokens) + + H = build_standard_ctc_topo(max_token_id=max_token_id) + + isym = kaldifst.SymbolTable() + isym.add_symbol(symbol="", key=0) + for i in range(1, max_token_id + 1): + isym.add_symbol(symbol=lexicon.id2token[i], key=i) + + osym = kaldifst.SymbolTable() + osym.add_symbol(symbol="", key=0) + for i in range(1, max_token_id + 1): + osym.add_symbol(symbol=lexicon.id2token[i], key=i) + + H.input_symbols = isym + H.output_symbols = osym + + fst_dot = kaldifst.draw(H, acceptor=False, portrait=True) + source = graphviz.Source(fst_dot) + source.render(outfile="standard_ctc_topo_yesno.pdf") + # See the link below to visualize the above PDF + # https://t.ly/7uXZ9 + + # Now test HL + + # We need to add one to all tokens since we want to use ID 0 + # for epsilon + add_one(H, treat_ilabel_zero_specially=False, update_olabel=True) + + add_disambig_self_loops( + H, + start=lexicon.token2id["#0"] + 1, + end=lexicon.max_disambig_id, + ) + + fst_dot = kaldifst.draw(H, acceptor=False, portrait=True) + source = graphviz.Source(fst_dot) + source.render(outfile="standard_ctc_topo_disambig_yesno.pdf") + + L = make_lexicon_fst_with_silence(lexicon) + + # We also need to change the input labels of L + add_one(L, treat_ilabel_zero_specially=True, update_olabel=False) + + H.output_symbols = None + + kaldifst.arcsort(H, sort_type="olabel") + kaldifst.arcsort(L, sort_type="ilabel") + HL = kaldifst.compose(H, L) + + lexicon.id2token[0] = "" + lexicon.token2id[""] = 0 + + isym = kaldifst.SymbolTable() + isym.add_symbol(symbol="", key=0) + for i in range(0, lexicon.max_disambig_id + 1): + isym.add_symbol(symbol=lexicon.id2token[i], key=i + 1) + + osym = kaldifst.SymbolTable() + for i, word in lexicon.id2word.items(): + osym.add_symbol(symbol=word, key=i) + + HL.input_symbols = isym + HL.output_symbols = osym + + fst_dot = kaldifst.draw(HL, acceptor=False, portrait=True) + source = graphviz.Source(fst_dot) + source.render(outfile="HL_yesno.pdf") + + +def test_librispeech(): + lang_dir = ( + "/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/data/lang_bpe_500" + ) + + if not Path(lang_dir).is_dir(): + print(f"{lang_dir} does not exist! 
Skip testing") + return + + lexicon = Lexicon(lang_dir) + HL = kaldifst.StdVectorFst.read(lang_dir + "/HL.fst") + + sp = spm.SentencePieceProcessor() + sp.load(lang_dir + "/bpe.model") + + i = lexicon.word2id["HELLOA"] + k = lexicon.word2id["WORLD"] + print(i, k) + s = f""" + 0 1 {i} {i} + 1 2 {k} {k} + 2 + """ + fst = kaldifst.compile( + s=s, + acceptor=False, + ) + + L = make_lexicon_fst_no_silence(lexicon, attach_symbol_table=False) + kaldifst.arcsort(L, sort_type="olabel") + with open("L.fst.txt", "w") as f: + print(L, file=f) + + fst = kaldifst.compose(L, fst) + print(fst) + fst_dot = kaldifst.draw(fst, acceptor=False, portrait=True) + source = graphviz.Source(fst_dot) + source.render(outfile="a.pdf") + print(sp.encode(["HELLOA", "WORLD"])) + + +def main(): + test_yesno() + test_librispeech() + + +if __name__ == "__main__": + main() diff --git a/icefall/ctc/test_prepare_lang.py b/icefall/ctc/test_prepare_lang.py new file mode 100755 index 000000000..6c4b9e510 --- /dev/null +++ b/icefall/ctc/test_prepare_lang.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +from pathlib import Path + +import graphviz +import kaldifst +from prepare_lang import Lexicon, make_lexicon_fst_with_silence + + +def test_yesno(): + lang_dir = "/Users/fangjun/open-source/icefall/egs/yesno/ASR/data/lang_phone" + if not Path(lang_dir).is_dir(): + print(f"{lang_dir} does not exist! Skip testing") + return + + lexicon = Lexicon(lang_dir) + + L = make_lexicon_fst_with_silence(lexicon) + + isym = kaldifst.SymbolTable() + for i, token in lexicon.id2token.items(): + isym.add_symbol(symbol=token, key=i) + + osym = kaldifst.SymbolTable() + for i, word in lexicon.id2word.items(): + osym.add_symbol(symbol=word, key=i) + + L.input_symbols = isym + L.output_symbols = osym + fst_dot = kaldifst.draw(L, acceptor=False, portrait=True) + source = graphviz.Source(fst_dot) + source.render(outfile="L_yesno.pdf") + # See the link below to visualize the above PDF + # https://t.ly/jMfXW + + +def main(): + test_yesno() + + +if __name__ == "__main__": + main() diff --git a/icefall/ctc/topo.py b/icefall/ctc/topo.py new file mode 100644 index 000000000..6a96dd038 --- /dev/null +++ b/icefall/ctc/topo.py @@ -0,0 +1,137 @@ +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + +import kaldifst + + +# Note the name contains `standard`; it means there will be non-standard +# topologies. +def build_standard_ctc_topo(max_token_id: int) -> kaldifst.StdVectorFst: + """Build a standard CTC topology. + + Args: + Maximum valid token ID. We assume token IDs are contiguous + and starts from 0. In other words, the vocabulary size is + ``max_token_id + 1``. We assume the ID of the blank symbol is 0. + """ + # Token ID starts from 0 and there are as many states as the + # number of tokens. + # + # Note that epsilon is not a token and the token with ID 0 in tokens.txt + # is not an epsilon. It means input label 0 of the resulting FST does + # not represent an epsilon. + # + # You can use the function `add_one()` to modify the input/output labels + # of the resulting FST + + num_states = max_token_id + 1 + + # Step 1: Create as many states as the number of tokens. + # Each state is a final state + fst = kaldifst.StdVectorFst() + for i in range(num_states): + s = fst.add_state() + fst.set_final(state=s, weight=0) + + # Step 2: Set state 0 as the start state. + # We assume the ID of the blank symbol is 0. + fst.start = 0 + + # Step 3: Build a fully connected graph. 
+ for i in range(num_states): + for k in range(num_states): + fst.add_arc( + state=i, + arc=kaldifst.StdArc( + ilabel=k, + olabel=k if i != k else 0, # if i==k, it is a self loop + weight=0, + nextstate=k, + ), + ) + # Please see ./test_ctc_topo.py if you want to know what the resulting + # FST looks like + + return fst + + +def add_one( + fst: kaldifst.StdVectorFst, + treat_ilabel_zero_specially: bool, + update_olabel: bool, +) -> None: + """Modify the input and output labels of the given FST in-place. + + Args: + fst: + The FST to be modified. It is changed in-place. + treat_ilabel_zero_specially: + If True, then every non-zero input label is increased by one and the + zero input label is not changed. + If False, then every input label is increased by one. + update_olabel: + If False, the output label is not changed. + If True, then every non-zero output label is increased by one. + In either case, output label with 0 is not changed. + """ + for state in kaldifst.StateIterator(fst): + for arc in kaldifst.ArcIterator(fst, state): + # If treat_ilabel_zero_specially is False, we always change it + # Otherwise, we only change non-zero input labels + if treat_ilabel_zero_specially is False or arc.ilabel != 0: + arc.ilabel += 1 + + if update_olabel and arc.olabel != 0: + arc.olabel += 1 + + if fst.input_symbols is not None: + input_symbols = kaldifst.SymbolTable() + input_symbols.add_symbol(symbol="", key=0) + + for i in range(0, fst.input_symbols.num_symbols()): + s = fst.input_symbols.find(i) + input_symbols.add_symbol(symbol=s, key=i + 1) + + fst.input_symbols = input_symbols + + if update_olabel and fst.output_symbols is not None: + output_symbols = kaldifst.SymbolTable() + output_symbols.add_symbol(symbol="", key=0) + + for i in range(0, fst.output_symbols.num_symbols()): + s = fst.output_symbols.find(i) + output_symbols.add_symbol(symbol=s, key=i + 1) + + fst.output_symbols = output_symbols + + +def add_disambig_self_loops(fst: kaldifst.StdVectorFst, start: int, end: int): + """Add self-loops to each state. + + For each disambig symbol, we add a self-loop with input label disambig_id + and output label diambig_id of that disambig symbol. + + Args: + fst: + It is changed in-place. + start: + The ID of #0 + end: + The ID of the last disambig symbol. For instance if there are 3 + disambig symbols ``#0``, ``#1``, and ``#2``, then ``end`` is the ID + of ``#2``. + """ + for state in kaldifst.StateIterator(fst): + for i in range(start, end + 1): + fst.add_arc( + state=state, + arc=kaldifst.StdArc( + ilabel=i, + olabel=i, + weight=0, + nextstate=state, + ), + ) + + if fst.output_symbols: + for i in range(start, end + 1): + fst.output_symbols.add_symbol(symbol=f"#{i-start}", key=i) diff --git a/requirements-ci.txt b/requirements-ci.txt index 652e2ab47..6f8739ce0 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -27,3 +27,4 @@ onnx onnxmltools onnxruntime kaldifst +kaldi-hmm-gmm diff --git a/requirements.txt b/requirements.txt index f0098c236..c031d683c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ kaldifst kaldilm kaldialign +kaldi-hmm-gmm sentencepiece>=0.1.96 tensorboard typeguard